Commit 88a4fbfb authored by Adam Osewski

A few fixes.

* Prevent clearing c_thread_buffer between GEMMs on consecutive K-dim data tiles.
* Limit the number of launched thread blocks.
parent 7ed791b8
@@ -89,7 +89,9 @@ __global__ void
// early exit if no work.
if(work_scheduler.tile_id_ >= tile_count)
{
return;
}
index_t group_id = 0;
index_t offset = 0;
@@ -101,6 +103,7 @@ __global__ void
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp;
auto gridwise_gemm = GridwiseGemm();
do
{
@@ -127,11 +130,13 @@ __global__ void
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
auto gridwise_gemm = GridwiseGemm();
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
results_buffer.Clear();
// Iterate over the K dimension for this [M,N] tile:
// stay in the same GEMM and the same [M,N] tile.
// TODO: change the descriptor so that a few K-tiles are processed in a single GEMM.
do
{
// just accumulate results in registers!
@@ -211,7 +216,6 @@ __global__ void
}
else if(work_scheduler.HasTile())
{
// TODO: use double buffering so that we do not have to wait here.
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
} while(work_scheduler.HasTile());
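
For clarity, the accumulation pattern the kernel loop above relies on can be sketched in isolation: the register accumulator is cleared only once per [M,N] output tile and then kept across the consecutive K-dim tiles of that same tile. All names in the sketch below are illustrative stand-ins, not the actual ck:: kernel API.

#include <vector>

// Illustrative per-thread accumulator standing in for GridwiseGemm's C thread buffer.
struct CThreadBuffer
{
    std::vector<float> acc;
    void Clear() { acc.assign(acc.size(), 0.f); }
};

// Hypothetical single K-slice GEMM step that accumulates into 'buf'.
void RunKSliceGemm(CThreadBuffer& buf, int /*k_tile*/) { /* ... accumulate ... */ }

void ProcessOutputTile(CThreadBuffer& buf, int num_k_tiles)
{
    buf.Clear(); // clear once, when a new [M,N] output tile starts
    for(int k = 0; k < num_k_tiles; ++k)
    {
        // NOT cleared here: partial results stay in registers between
        // consecutive K-dim tiles of the same output tile.
        RunKSliceGemm(buf, k);
    }
    // The partial [M,N] result is now ready for the split-K reduction.
}
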
@@ -747,9 +751,17 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// configuration the first is smaller than the latter. Launching too many workgroups
// means some of them will have to iterate through all gemm problem descriptors just to
// find out they have nothing to do, which is of course a waste of GPU cycles.
const index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
// Make a correction to the grid size in order to get rid of workgroups which do not
// have anything to work on.
if(arg.tile_count_ > max_occupancy_grid_size &&
grid_size * tiles_per_block > arg.tile_count_)
{
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
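
The grid-size correction above can be restated in isolation. The sketch below mirrors the arithmetic with plain ints; `tile_count` and `max_occupancy_grid_size` echo the arguments in the hunk, but the function itself is only an illustration, not library code.

#include <algorithm>

// Illustrative restatement of the grid-size computation: launch at most
// 'max_occupancy_grid_size' workgroups, then shrink the grid so that no
// workgroup is left without a tile to process.
int ComputeGridSize(int tile_count, int max_occupancy_grid_size)
{
    int grid_size       = std::min(tile_count, max_occupancy_grid_size);
    int tiles_per_block = (tile_count + grid_size - 1) / grid_size; // ceil-div
    if(tile_count > max_occupancy_grid_size && grid_size * tiles_per_block > tile_count)
    {
        // Example: tile_count = 1000, max_occupancy_grid_size = 384:
        // tiles_per_block = 3 and 384 * 3 = 1152 > 1000, so the grid shrinks
        // to (1000 + 2) / 3 = 334 workgroups, each still getting work.
        grid_size = (tile_count + tiles_per_block - 1) / tiles_per_block;
    }
    return grid_size;
}
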
@@ -774,11 +786,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
};
return launch_and_time_kernel_with_preprocess(
// return launch_and_time_kernel(
stream_config,
preprocess,
kernel,
dim3(ck::math::min(arg.tile_count_, max_occupancy_grid_size)),
dim3(grid_size),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
@@ -754,6 +754,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
KPerBlock);
bool clear_c_thread_buf = false;
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
a_block_desc_kbatch_ak0_m_ak1,
a_blockwise_copy,
@@ -768,7 +770,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
b_block_slice_copy_step,
blockwise_gemm_,
c_thread_buf,
num_k_block_main_loop);
num_k_block_main_loop,
clear_c_thread_buf);
}
template <bool HasMainKBlockLoop, typename Block2ETileMap>
@@ -55,7 +55,8 @@ struct GridwiseGemmPipeline_v1<1>
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
index_t num_loop,
bool clear_c_thread_buf = true)
{
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -65,7 +66,8 @@ struct GridwiseGemmPipeline_v1<1>
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
if(clear_c_thread_buf)
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
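
Because the new clear_c_thread_buf parameter defaults to true, call sites that are not aware of the flag keep the old always-clear behaviour, while the split-K grouped-GEMM path passes false to keep accumulating across K-dim tiles. A tiny standalone sketch of that design choice (the struct and values are made up for illustration, not the ck:: pipeline interface):

struct PipelineSketch
{
    float c_acc = 0.f;

    void Run(float partial, bool clear_c_thread_buf = true)
    {
        if(clear_c_thread_buf)
            c_acc = 0.f;  // old behaviour: start from zero on every call
        c_acc += partial; // accumulate the current K-slice
    }
};

void Example()
{
    PipelineSketch p;
    p.Run(1.f);        // legacy call site, unchanged: c_acc == 1.f
    p.Run(2.f, false); // split-K call site: keeps accumulating, c_acc == 3.f
}
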
@@ -93,7 +93,6 @@ class StridedReductionTileLoop
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx);
}