Commit 88a4fbfb authored by Adam Osewski's avatar Adam Osewski
Browse files

A few fixes.

* Prevent clearing c_thread_buffer between consecutive K-dim data-tile GEMMs.
* Limit the number of launched thread blocks.
parent 7ed791b8
...@@ -89,7 +89,9 @@ __global__ void ...@@ -89,7 +89,9 @@ __global__ void
// early exit if no work. // early exit if no work.
if(work_scheduler.tile_id_ >= tile_count) if(work_scheduler.tile_id_ >= tile_count)
{
return; return;
}
index_t group_id = 0; index_t group_id = 0;
index_t offset = 0; index_t offset = 0;
...@@ -101,6 +103,7 @@ __global__ void ...@@ -101,6 +103,7 @@ __global__ void
index_t gemm_tile_id_start = 0; index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp; index_t gemm_tile_id_end = grid_size_grp;
auto gridwise_gemm = GridwiseGemm();
do do
{ {
...@@ -127,11 +130,13 @@ __global__ void ...@@ -127,11 +130,13 @@ __global__ void
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB; const auto StrideB = gemm_desc_ptr[group_id].StrideB;
auto gridwise_gemm = GridwiseGemm(); auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset); b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
results_buffer.Clear();
// Iterate over K dimension for this [M,N] tile // Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile // still in the same GEMM && the same [M,N] tile
// TODO: change desc so that few K-tiles will be done in single GEMM.
do do
{ {
// just accumulate results in registers! // just accumulate results in registers!
...@@ -211,7 +216,6 @@ __global__ void ...@@ -211,7 +216,6 @@ __global__ void
} }
else if(work_scheduler.HasTile()) else if(work_scheduler.HasTile())
{ {
// TODO: double buffering in order to not wait for this.
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset); work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
} }
} while(work_scheduler.HasTile()); } while(work_scheduler.HasTile());
...@@ -747,9 +751,17 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -747,9 +751,17 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// configuration the first is smaller than the latter. Launching too many workgroups // configuration the first is smaller than the latter. Launching too many workgroups
// mean some of them will have to iterate through all gemm problem descriptors just to // mean some of them will have to iterate through all gemm problem descriptors just to
// find out they have nothing to do which is of course waste of GPU cycles. // find out they have nothing to do which is of course waste of GPU cycles.
const index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size); index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size; const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
// Make a correction to grid size in order to get rid of workgroups which does not have
// anything to work.
if(arg.tile_count_ > max_occupancy_grid_size &&
grid_size * tiles_per_block > arg.tile_count_)
{
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize( std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size); sizeof(typename GridwiseGemm::AccType), grid_size);
...@@ -774,11 +786,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -774,11 +786,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
}; };
return launch_and_time_kernel_with_preprocess( return launch_and_time_kernel_with_preprocess(
// return launch_and_time_kernel(
stream_config, stream_config,
preprocess, preprocess,
kernel, kernel,
dim3(ck::math::min(arg.tile_count_, max_occupancy_grid_size)), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(dev_gemm_args), cast_pointer_to_constant_address_space(dev_gemm_args),
......
...@@ -754,6 +754,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -754,6 +754,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
KPerBlock); KPerBlock);
bool clear_c_thread_buf = false;
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
a_block_desc_kbatch_ak0_m_ak1, a_block_desc_kbatch_ak0_m_ak1,
a_blockwise_copy, a_blockwise_copy,
...@@ -768,7 +770,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -768,7 +770,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm_, blockwise_gemm_,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop,
clear_c_thread_buf);
} }
template <bool HasMainKBlockLoop, typename Block2ETileMap> template <bool HasMainKBlockLoop, typename Block2ETileMap>
......
...@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1<1>
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop,
bool clear_c_thread_buf = true)
{ {
// preload data into LDS // preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
...@@ -65,7 +66,8 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -65,7 +66,8 @@ struct GridwiseGemmPipeline_v1<1>
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
c_thread_buf.Clear(); if(clear_c_thread_buf)
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
......
...@@ -93,7 +93,6 @@ class StridedReductionTileLoop ...@@ -93,7 +93,6 @@ class StridedReductionTileLoop
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset) FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{ {
const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset); const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx); finished_block_flags_.inc(fidx);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment