Commit 7e71ea99 authored by Adam Osewski

Commit debug WIP for sharing.

parent 734df790
@@ -213,7 +213,7 @@
 #define CK_WORKAROUND_SWDEV_388832 1
 // flag to enable (1) or disable (0) the debugging output in some kernels
-#define DEBUG_LOG 0
+#define DEBUG_LOG 1
 // denorm test fix, required to work around dissue
 #ifndef CK_WORKAROUND_DENORM_FIX
...
@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
            block_dim.y,
            block_dim.z);
-    printf("Warm up 1 time\n");
+    printf("Warm up %d times\n", stream_config.cold_niters_);
 #endif
     // warm up
-    preprocess();
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
+    for(int i = 0; i < stream_config.cold_niters_; ++i)
+    {
+        preprocess();
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        hip_check_error(hipGetLastError());
+    }
-    const int nrepeat = 10;
+    const int nrepeat = stream_config.nrepeat_;
 #if DEBUG_LOG
     printf("Start running %d times...\n", nrepeat);
 #endif
...
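For reference, the warm-up / timed-repeat flow that this hunk moves onto StreamConfig boils down to the pattern below. This is a minimal HIP sketch under the assumption that cold_niters_ counts warm-up launches and nrepeat_ the timed ones; it is not the CK helper itself.

```cpp
#include <hip/hip_runtime.h>

// Toy kernel standing in for the real templated kernel argument.
__global__ void dummy_kernel(float* p) { p[blockIdx.x * blockDim.x + threadIdx.x] += 1.0f; }

// Sketch of "warm up cold_niters times, then average over nrepeat timed launches".
float time_kernel_avg_ms(float* p, dim3 grid, dim3 block, int cold_niters, int nrepeat)
{
    for(int i = 0; i < cold_niters; ++i) // warm-up launches, not timed
    {
        dummy_kernel<<<grid, block>>>(p);
    }
    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);
    (void)hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i) // timed launches
    {
        dummy_kernel<<<grid, block>>>(p);
    }
    (void)hipEventRecord(stop, nullptr);
    (void)hipEventSynchronize(stop);
    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop);
    (void)hipEventDestroy(start);
    (void)hipEventDestroy(stop);
    return nrepeat > 0 ? total_ms / nrepeat : 0.f;
}
```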
@@ -68,15 +68,15 @@ __global__ void
         void* const __restrict__ p_workspace,
         const index_t tile_count,
         const index_t k_batch,
-        const AElementwiseOperation a_element_op,
-        const BElementwiseOperation b_element_op,
+        [[maybe_unused]] const AElementwiseOperation a_element_op,
+        [[maybe_unused]] const BElementwiseOperation b_element_op,
         [[maybe_unused]] const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    __shared__ uint8_t p_shared[shared_size];
+    [[maybe_unused]] __shared__ uint8_t p_shared[shared_size];
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -105,6 +105,12 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;
     auto gridwise_gemm = GridwiseGemm();
+
+    [[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool
+    {
+        const auto tid = get_thread_local_1d_id();
+        return ((tid == Ids) || ... );
+    };
     do
     {
         // Find corresponding GEMM group for our tile
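The is_thread_local_1d_id_idx helper added above uses a C++17 fold expression to gate debug printf output to a chosen set of thread ids. A host-side illustration of the same idiom, written in plain C++ and independent of the CK device helpers:

```cpp
#include <cstdio>

// ((tid == Ids) || ...) expands to (tid == I0) || (tid == I1) || ... at compile time.
template <int... Ids>
bool is_one_of(int tid)
{
    return ((tid == Ids) || ...);
}

int main()
{
    for(int tid : {0, 1, 63, 64, 255})
    {
        if(is_one_of<0, 64>(tid)) // print only for "threads" 0 and 64
        {
            std::printf("tid %d passes the filter\n", tid);
        }
    }
    return 0;
}
```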
@@ -123,12 +129,12 @@ __global__ void
             gemm_tile_id_end = offset + grid_size_grp;
         }
-        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
-        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
-        const auto K = gemm_desc_ptr[group_id].K;
-        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
-        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+        [[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        [[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        [[maybe_unused]] const auto K = gemm_desc_ptr[group_id].K;
+        [[maybe_unused]] const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+        [[maybe_unused]] const auto StrideB = gemm_desc_ptr[group_id].StrideB;
         auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
@@ -137,21 +143,32 @@ __global__ void
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         // TODO: change desc so that few K-tiles will be done in single GEMM.
+        // {
+        //     if (is_thread_local_1d_id_idx(0))
+        //     {
+        //         printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
+        //                static_cast<index_t>(blockIdx.x),
+        //                group_id,
+        //                b2c_tile_map.GetTileMIdx(),
+        //                b2c_tile_map.GetTileNIdx(),
+        //                b2c_tile_map.GetTileKIdx());
+        //     }
+        // }
         do
         {
             // just accumulate results in registers!
-            gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-                p_b_grid,
-                static_cast<void*>(p_shared),
-                a_element_op,
-                b_element_op,
-                M,
-                N,
-                K,
-                StrideA,
-                StrideB,
-                k_batch,
-                b2c_tile_map);
+            // gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+            //     p_b_grid,
+            //     static_cast<void*>(p_shared),
+            //     a_element_op,
+            //     b_element_op,
+            //     M,
+            //     N,
+            //     K,
+            //     StrideA,
+            //     StrideB,
+            //     k_batch,
+            //     b2c_tile_map);
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
@@ -167,51 +184,122 @@ __global__ void
         work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
+        // {
+        //     // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
+        //     //     work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
+        if (is_thread_local_1d_id_idx(0))
+        {
+            printf("bid: %d, group: %d, FlagFInished \n",
+                   static_cast<index_t>(blockIdx.x),
+                   group_id);
+            // printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
+            //        static_cast<index_t>(blockIdx.x),
+            //        group_id)
+            //        work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+            //        flag_v2);
+        }
+        // }
         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
+            if (is_thread_local_1d_id_idx(0))
+            {
+                printf("bid: %d, group: %d, Will wait for neighbours... \n",
+                       static_cast<index_t>(blockIdx.x),
+                       group_id);
+            }
             // Wait untill all other blocks for this [M,N] tile store their results.
             work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
             // Accumulate partial results. We can have different # of workgroups to reduce, thus we
             // read actual flag value.
-            const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
+            [[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
                 work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
+            // {
+            //     if (is_thread_local_1d_id_idx(0))
+            //     {
+            //         printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
+            //                static_cast<index_t>(blockIdx.x),
+            //                group_id,
+            //                work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+            //                static_cast<index_t>(blockIdx.x));
+            //         // flag_v);
+            //     }
+            // }
+            // using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
+            // constexpr index_t n_v = CThreadBuffer::num_of_v_.value;
+            // constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
+            // static_for<0, n_v, 1>{}([&](auto v) {
+            //     static_for<0, s_per_v, 1>{}([&](auto s) {
+            //         // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]:
+            //         // %f\n",
+            //         //        static_cast<index_t>(blockIdx.x),
+            //         //        static_cast<index_t>(threadIdx.x),
+            //         //        v.value,
+            //         //        s.value,
+            //         //        static_cast<float>(results_buffer[v * Number<s_per_v>{} + s])
+            //         //        );
+            //         results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
+            //     });
+            // });
             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            if(flag_v > 1)
-                gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
+            // if(flag_v > 1)
+            //     gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
+            if (is_thread_local_1d_id_idx(0))
+            {
+                printf("bid: %d, group: %d, Reset flag \n",
+                       static_cast<index_t>(blockIdx.x),
+                       group_id);
+            }
             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
-            const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
-            const auto stride_e = gemm_desc_ptr[group_id].StrideE;
-            const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
-            constexpr auto NumDTensor = DsDataType::Size();
-            using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
-            DsGridPointer p_ds_grid;
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-                p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
-            });
-            gridwise_gemm.template RunWrite(p_ds_grid,
-                p_e_grid,
-                static_cast<void*>(p_shared),
-                M,
-                N,
-                stride_ds,
-                stride_e,
-                cde_element_op,
-                b2c_tile_map);
+            // const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
+            // const auto stride_e = gemm_desc_ptr[group_id].StrideE;
+            // const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
+            // constexpr auto NumDTensor = DsDataType::Size();
+            // using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
+            // DsGridPointer p_ds_grid;
+            // static_for<0, NumDTensor, 1>{}([&](auto i) {
+            //     using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+            //     p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+            // });
+            // gridwise_gemm.template RunWrite(p_ds_grid,
+            //     p_e_grid,
+            //     static_cast<void*>(p_shared),
+            //     M,
+            //     N,
+            //     stride_ds,
+            //     stride_e,
+            //     cde_element_op,
+            //     b2c_tile_map);
         }
         else if(work_scheduler.HasTile())
         {
+            {
+                // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
+                const uint32_t flag_v2 = work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
+                if (is_thread_local_1d_id_idx(0))
+                {
+                    printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u\n",
+                           static_cast<index_t>(blockIdx.x),
+                           group_id,
+                           work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
+                           // static_cast<index_t>(blockIdx.x));
+                           flag_v2);
+                }
+            }
             work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
         }
     } while(work_scheduler.HasTile());
@@ -751,7 +839,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                         Block2ETileMapKSplit::GetAccWorkspaceSize(
                             sizeof(typename GridwiseGemm::AccType), grid_size);
-        std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+        // std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+        std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;
         if(stream_config.log_level_ > 0)
         {
@@ -987,7 +1076,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             arg.gpu_cu_count_ * std::min(arg.occupancy_num_blocks_, KernelConfig::CU_BLOCKS);
         int grid_size = std::min(arg.tile_count_, occ_grid_size);
         int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
-        int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+
+        if(arg.tile_count_ > occ_grid_size &&
+           grid_size * tiles_per_block > arg.tile_count_)
+        {
+            grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
+        }
+
+        // int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+        int flag_count = arg.tile_count_ / arg.K_BATCH;
         // This would be the maximum needed workspace size. Since actual grid size, which determines
         // the amount of workspace bytes needed, may be less due to the number of available CUs in
...
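The flag-count change in both hunks swaps a grid-derived upper bound for an exact per-output-tile count, and the new clamp keeps grid_size from overshooting the actual tile count. A small standalone check of the arithmetic with made-up sizes:

```cpp
#include <cstdio>

int main()
{
    // Hypothetical numbers, chosen only to illustrate the two formulas.
    const int tile_count    = 96; // total work tiles = output tiles * K_BATCH
    const int K_BATCH       = 4;
    const int occ_grid_size = 40;

    int grid_size       = tile_count < occ_grid_size ? tile_count : occ_grid_size; // 40
    int tiles_per_block = (tile_count + grid_size - 1) / grid_size;                // ceil(96/40) = 3

    // Old: flags sized from the padded grid_size * tiles_per_block product.
    int flag_count_old = (grid_size * tiles_per_block + K_BATCH - 1) / K_BATCH;    // ceil(120/4) = 30

    // New: one flag per MN-output tile.
    int flag_count_new = tile_count / K_BATCH;                                     // 96/4 = 24

    // New clamp from the hunk above: shrink grid_size so it does not overshoot tile_count.
    if(tile_count > occ_grid_size && grid_size * tiles_per_block > tile_count)
    {
        grid_size = (tile_count + tiles_per_block - 1) / tiles_per_block;          // ceil(96/3) = 32
    }

    std::printf("grid_size=%d tiles_per_block=%d old_flags=%d new_flags=%d\n",
                grid_size, tiles_per_block, flag_count_old, flag_count_new);
    return 0;
}
```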
@@ -106,6 +106,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+    template <index_t... Ids>
+    __device__ static bool is_thread_local_1d_id_idx()
+    {
+        const auto tid = get_thread_local_1d_id();
+        return ((tid == Ids) || ...);
+    }
+
     public:
     using AccType = AccDataType;
@@ -906,6 +913,32 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                        Sequence<6>{},
                        Sequence<7>{}));
+        // if (is_thread_local_1d_id_idx<0>())
+        // {
+        //     // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //     //        static_cast<index_t>(blockIdx.x),
+        //     //        static_cast<index_t>(threadIdx.x),
+        //     //        M0.value,
+        //     //        N0.value,
+        //     //        M1.value,
+        //     //        N1.value,
+        //     //        M2.value,
+        //     //        M3.value,
+        //     //        M4.value,
+        //     //        N2.value);
+        //     printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
+        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
+        // }
+
         auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
         auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
@@ -963,11 +996,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};
+        // if (is_thread_local_1d_id_idx<0, 64, 223>())
+        // {
+        //     printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
+        //            n_thread_data_on_block_idx[I0],
+        //            m_thread_data_on_block_idx[I1],
+        //            n_thread_data_on_block_idx[I1],
+        //            m_thread_data_on_block_idx[I2],
+        //            m_thread_data_on_block_idx[I3],
+        //            m_thread_data_on_block_idx[I4],
+        //            n_thread_data_on_block_idx[I2]);
+        // }
+
         c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                        c_thread_buf,
                                        workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        w_grid_buf);
+
+        if (is_thread_local_1d_id_idx<0>())
+        {
+            printf("[StorePartials] done. bid: %d, tid: %d\n",
+                   static_cast<index_t>(blockIdx.x),
+                   static_cast<index_t>(threadIdx.x));
+        }
     }
     __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
...
@@ -51,7 +51,7 @@ class StridedReductionTileLoop
     {
         tile_id_++;
        block_tile_idx_++;
-        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
+        return HasTile();
     }
     __device__ index_t GetFlagCount(index_t k_tiles) const
@@ -75,11 +75,12 @@ class StridedReductionTileLoop
     ///
     /// @return The workgroup flag index.
     ///
-    __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
+    __device__ uint32_t GetWorkgroupFlagIdx([[maybe_unused]] index_t k_tiles,
                                             index_t output_tile_idx,
                                             index_t output_tile_idx_offset) const
     {
-        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
+        // return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
+        return output_tile_idx + output_tile_idx_offset;
     }
     ///
@@ -92,7 +93,7 @@ class StridedReductionTileLoop
     __device__ void
     FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
     {
-        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
+        /* [[maybe_unused]] */const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
         finished_block_flags_.inc(fidx);
     }
@@ -111,8 +112,10 @@ class StridedReductionTileLoop
         // We use < because for some cases we may have +1 more workgroups per dim.
         // Ie when k_tiles = 5, tiles_per_block = 3.
         finished_block_flags_.wait_lt(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
             workgroups_per_dim);
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }
     ///
@@ -128,6 +131,8 @@ class StridedReductionTileLoop
         // Wait untill the counter has been reset.
         finished_block_flags_.wait_eq(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }
     ///
@@ -141,6 +146,8 @@ class StridedReductionTileLoop
     {
         finished_block_flags_.reset(
             GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
+        // [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
     }
     ///
...
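The flag protocol exercised above (FlagFinished, WaitForNeighbours, WaitForReduction, Reset) is a per-output-tile counter handshake between the k_batch workgroups of a split-K GEMM. A host-side sketch of the same ordering using std::atomic, purely as an illustration; the device code uses its own flag buffer and wait primitives:

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// One counter per MN-output tile; each of the k_batch "workgroups" bumps it
// once its partial tile has been written to the workspace.
struct TileFlags
{
    std::vector<std::atomic<uint32_t>> flags;
    explicit TileFlags(std::size_t n) : flags(n) {}

    void flag_finished(std::size_t idx) { flags[idx].fetch_add(1, std::memory_order_release); }

    // First-K workgroup: wait until all k_batch partials are flagged.
    void wait_for_neighbours(std::size_t idx, uint32_t k_batch)
    {
        while(flags[idx].load(std::memory_order_acquire) < k_batch) { std::this_thread::yield(); }
    }

    // Reducer signals that the workspace slots may be reused.
    void reset(std::size_t idx) { flags[idx].store(0, std::memory_order_release); }

    // Other workgroups: wait until the reducer has reset the counter.
    void wait_for_reduction(std::size_t idx)
    {
        while(flags[idx].load(std::memory_order_acquire) != 0) { std::this_thread::yield(); }
    }
};

int main()
{
    constexpr uint32_t k_batch = 4;
    TileFlags flags(1); // a single output tile for the demo

    std::vector<std::thread> workers;
    for(uint32_t k = 0; k < k_batch; ++k)
    {
        workers.emplace_back([&, k] {
            flags.flag_finished(0);          // "partial result stored"
            if(k == 0)                       // pretend k == 0 is the first-K-split block
            {
                flags.wait_for_neighbours(0, k_batch);
                std::printf("reducer saw all %u partials\n", k_batch);
                flags.reset(0);
            }
            else
            {
                flags.wait_for_reduction(0); // wait before reusing the workspace slot
            }
        });
    }
    for(auto& t : workers) t.join();
    return 0;
}
```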
@@ -16,96 +16,96 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Row,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Row,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Col,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Col,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-        Row,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+//         Row,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-        Col,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+//         Col,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Col,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Col,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Row,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Row,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Col,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Col,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
 void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
         PassThrough,
         PassThrough>>>& instances);
-void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Row,
-        Empty_Tuple,
-        Row,
-        F16,
-        F8,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Row,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F8,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
-void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Row,
-        Empty_Tuple,
-        Row,
-        F8,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Row,
+//         Empty_Tuple,
+//         Row,
+//         F8,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
 template <typename ALayout,
           typename BLayout,
@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<ELayout, Row>)
             {
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                 add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                               is_same_v<ELayout, Row>)
             {
-                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-                    op_ptrs);
+                // add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+                //     op_ptrs);
             }
-            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-                              is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
-            }
-            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-                              is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
-            }
+            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+            //                   is_same_v<ELayout, Row>)
+            // {
+            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            // }
+            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+            //                   is_same_v<ELayout, Row>)
+            // {
+            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            // }
         }
-        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
-                          is_same_v<EDataType, half_t>)
-        {
-            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
-            }
-        }
-        else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
-                          is_same_v<EDataType, half_t>)
-        {
-            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<ELayout, Row>)
-            {
-                add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
-            }
-        }
+        // else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
+        //                   is_same_v<EDataType, half_t>)
+        // {
+        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+        //                  is_same_v<ELayout, Row>)
+        //     {
+        //         add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        //     }
+        // }
+        // else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
+        //                   is_same_v<EDataType, half_t>)
+        // {
+        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+        //                  is_same_v<ELayout, Row>)
+        //     {
+        //         add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        //     }
+        // }
 
         return op_ptrs;
     }
 };
...
@@ -17,18 +17,18 @@ namespace device {
 namespace instance {
 // MultiD version
-void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-        Col,
-        Empty_Tuple,
-        Row,
-        F16,
-        F16,
-        Empty_Tuple,
-        F16,
-        PassThrough,
-        PassThrough,
-        PassThrough>>>& instances);
+// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+//         Col,
+//         Empty_Tuple,
+//         Row,
+//         F16,
+//         F16,
+//         Empty_Tuple,
+//         F16,
+//         PassThrough,
+//         PassThrough,
+//         PassThrough>>>& instances);
 
 void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
-            add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-                op_ptrs);
+            // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+            //     op_ptrs);
         }
     }
     return op_ptrs;
...
 add_instance_library(device_grouped_gemm_instance
-    device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
-    device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
-    device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
-    device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
-    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
-    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+    # device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+    # device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+    # device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+    # device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
-    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-    device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
-    device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
+    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    # device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
+    # device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
 add_instance_library(device_grouped_gemm_multiple_d_instance
-    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
     device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
@@ -219,6 +219,8 @@ bool profile_ggemm_multid_splitk(int do_verification,
     // profile device GEMM instances
     for(auto& gemm_ptr : op_ptrs)
     {
+        std::cout << "Running instance: " << gemm_ptr->GetTypeString() << std::endl;
+
         auto gptr = dynamic_cast<DeviceOp*>(gemm_ptr.get());
         auto argument_ptr = gemm_ptr->MakeArgumentPointer(
@@ -247,20 +249,24 @@ bool profile_ggemm_multid_splitk(int do_verification,
         for(std::size_t j = 0; j < kbatch_list.size(); j++)
         {
             auto kbatch_curr = kbatch_list[j];
+            // std::cout << ">>> kbatch: " << kbatch_curr << std::endl;
             gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);
             DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
             gemm_ptr->SetWorkSpacePointer(argument_ptr.get(),
                                           gemm_desc_workspace.GetDeviceBuffer());
+            // std::cout << "WorkspacePointer set!" << std::endl;
             if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
             {
                 for(std::size_t i = 0; i < gemm_descs.size(); i++)
                     c_device_buf[i]->SetZero();
-                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+                // invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
+                // std::cout << ">>>>>GPU Run end!" << std::endl;
                 if(do_verification)
                 {
@@ -304,12 +310,16 @@ bool profile_ggemm_multid_splitk(int do_verification,
                               << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
                     pass = pass && instance_pass;
+                    std::cout << ">>>>>CPU verification end!" << std::endl;
                 }
                 if(time_kernel)
                 {
-                    float avg_time =
-                        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+                    std::cout << ">>>>>GPU time profiling start!" << std::endl;
+                    float avg_time = invoker_ptr->Run(
+                        // argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
+                        argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 0, 1});
                     std::size_t flop = 0, num_btype = 0;
                     for(std::size_t i = 0; i < gemm_descs.size(); i++)
                     {
@@ -335,6 +345,7 @@ bool profile_ggemm_multid_splitk(int do_verification,
                             best_gb_per_sec = gb_per_sec;
                             best_kbatch = kbatch_curr;
                         }
+                        // std::cout << ">>>>>GPU time profiling end!" << std::endl;
                     }
                 }
                 else
...
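The extra StreamConfig initializers used above ({nullptr, time_kernel, 1, 0, 1} and the commented {nullptr, time_kernel, 1, 5, 30}) line up with the warm-up/repeat fields used by the launch helper earlier in this commit. The stand-in below only exists to make the positional arguments readable; the field names past cold_niters_ and nrepeat_ are assumptions, not the actual CK definition.

```cpp
#include <hip/hip_runtime.h>

// Hypothetical mirror of the positional initializers seen above;
// the real ck::StreamConfig may differ.
struct StreamConfigSketch
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
    int log_level_         = 0;
    int cold_niters_       = 5;  // warm-up launches (see launch_and_time_kernel_with_preprocess hunk)
    int nrepeat_           = 50; // timed launches
};

// StreamConfigSketch{nullptr, true, 1, 0, 1}  -> no warm-up, a single timed launch (debug-friendly)
// StreamConfigSketch{nullptr, true, 1, 5, 30} -> 5 warm-ups, average over 30 timed launches
```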
 # ckProfiler
 set(PROFILER_SOURCES
     profiler.cpp
-    profile_gemm.cpp
-    profile_gemm_splitk.cpp
-    profile_gemm_bias_add_reduce.cpp
-    profile_gemm_add_multiply.cpp
-    profile_gemm_multiply_add.cpp
-    profile_gemm_reduce.cpp
-    profile_batched_gemm.cpp
-    profile_batched_gemm_reduce.cpp
-    profile_conv_fwd.cpp
-    profile_conv_fwd_bias_relu.cpp
-    profile_conv_fwd_bias_relu_add.cpp
-    profile_conv_bwd_data.cpp
-    profile_grouped_conv_fwd.cpp
-    profile_grouped_conv_bwd_weight.cpp
-    profile_reduce.cpp
-    profile_groupnorm_bwd_data.cpp
-    profile_groupnorm_fwd.cpp
-    profile_layernorm_bwd_data.cpp
-    profile_layernorm_fwd.cpp
-    profile_max_pool3d_fwd.cpp
-    profile_avg_pool3d_bwd.cpp
-    profile_max_pool3d_bwd.cpp
-    profile_softmax.cpp
-    profile_batchnorm_fwd.cpp
-    profile_batchnorm_bwd.cpp
-    profile_batchnorm_infer.cpp
-    profile_grouped_conv_bwd_data.cpp
-    profile_conv_tensor_rearrange.cpp
+    # profile_gemm.cpp
+    # profile_gemm_splitk.cpp
+    # profile_gemm_bias_add_reduce.cpp
+    # profile_gemm_add_multiply.cpp
+    # profile_gemm_multiply_add.cpp
+    # profile_gemm_reduce.cpp
+    # profile_batched_gemm.cpp
+    # profile_batched_gemm_reduce.cpp
+    # profile_conv_fwd.cpp
+    # profile_conv_fwd_bias_relu.cpp
+    # profile_conv_fwd_bias_relu_add.cpp
+    # profile_conv_bwd_data.cpp
+    # profile_grouped_conv_fwd.cpp
+    # profile_grouped_conv_bwd_weight.cpp
+    # profile_reduce.cpp
+    # profile_groupnorm_bwd_data.cpp
+    # profile_groupnorm_fwd.cpp
+    # profile_layernorm_bwd_data.cpp
+    # profile_layernorm_fwd.cpp
+    # profile_max_pool3d_fwd.cpp
+    # profile_avg_pool3d_bwd.cpp
+    # profile_max_pool3d_bwd.cpp
+    # profile_softmax.cpp
+    # profile_batchnorm_fwd.cpp
+    # profile_batchnorm_bwd.cpp
+    # profile_batchnorm_infer.cpp
+    # profile_grouped_conv_bwd_data.cpp
+    # profile_conv_tensor_rearrange.cpp
 )
 if(DL_KERNELS)
@@ -36,21 +36,22 @@ if(DL_KERNELS)
 endif()
 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
+    # list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
+    # list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
     list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
+    # list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
+    list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiple_d_splitk.cpp)
 endif()
 if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
+    # list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
+    # list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
 endif()
 set(PROFILER_EXECUTABLE ckProfiler)
@@ -59,42 +60,42 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
 target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
 if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
 endif()
@@ -104,16 +105,17 @@ if(DL_KERNELS)
 endif()
 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
+    # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
+    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_multiple_d_instance)
 endif()
 rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
@@ -154,9 +154,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         // Accumulate partial results. We can have different # of workgroups to reduce, thus we
         // read actual flag value.
-        const index_t flag_v = __builtin_amdgcn_readfirstlane(
+        const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
             work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-        for(index_t i = 1; i < flag_v; ++i)
+        for(uint32_t i = 1; i < flag_v; ++i)
         {
             partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                           i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
@@ -174,7 +174,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
             p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
                 C_thread_tile_n_idx] = partial_result;
         }
-        else
+        else if(work_scheduler.HasTile())
        {
             work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
         }
@@ -284,10 +284,11 @@ struct GroupedGemmStridedTileLoopReduce
         DeviceMem gemm_workspace, gemm_flags;
-        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
+        // const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
         // This is the number of MN-output tiles which we cover with workgroups.
         // We launch k_batch / tiles_per_block workgroups for each output tile.
-        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+        // const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+        const index_t flag_count = tile_count / k_batch;
         gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
         gemm_flags.Realloc(flag_count * sizeof(uint32_t));
...
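The reduction loop above adds up flag_v partial tiles that the split-K workgroups stored into the float workspace at MPerBlock * NPerBlock strides. The indexing reduces to the following host-side sketch (made-up sizes, not the kernel itself):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical tile geometry: each of k_batch workgroups wrote an
    // MPerBlock x NPerBlock partial tile into the shared float workspace.
    constexpr int MPerBlock = 4, NPerBlock = 4, k_batch = 3;
    constexpr int tile_elems = MPerBlock * NPerBlock;

    std::vector<float> workspace(k_batch * tile_elems, 1.0f); // pretend every partial is all-ones

    // One "thread" per output element accumulates across the k_batch partials,
    // mirroring: partial_result += p_workspace[i * MPerBlock * NPerBlock + tid].
    std::vector<float> out(tile_elems);
    for(int tid = 0; tid < tile_elems; ++tid)
    {
        float acc = workspace[tid];          // partial from the first K-split
        for(int i = 1; i < k_batch; ++i)     // remaining K-splits, like the flag_v loop
            acc += workspace[i * tile_elems + tid];
        out[tid] = acc;
    }

    std::printf("out[0] = %.1f (expected %d)\n", out[0], k_batch); // prints 3.0
    return 0;
}
```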