Commit ad0e4083 authored by Adam Osewski

Multiple changes to the global kernel function.

* StorePartials works on an offset, per-workgroup pointer into the accumulation workspace (see the sketch below).
* Read the flags as uint32_t values.
* Accumulate partials only if there is more than one cooperating workgroup.
* Wait for the reduction to end only when there is still work to do.
* Fix creation of the A/B grid descriptors in CheckArgument.
* LaunchKernel uses a preprocess lambda to zero the flags.
* Add a check in IsSupportedArgument that XDL is supported.
parent f8cbbd1b
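A minimal sketch of the byte-offset arithmetic behind the first bullet, purely illustrative: MPerBlock, NPerBlock and AccType below are placeholders standing in for GridwiseGemm::GetMPerBlock(), GridwiseGemm::GetNPerBlock() and GridwiseGemm::AccType used in the diff, and the sizes chosen are arbitrary.

#include <cstddef>
#include <cstdio>

using AccType = float;

constexpr std::size_t MPerBlock = 128;
constexpr std::size_t NPerBlock = 128;

// Each workgroup owns one MPerBlock x NPerBlock tile of partial results inside
// the flat accumulation workspace, addressed by its block index.
void* block_workspace(void* p_workspace, std::size_t block_id)
{
    const std::size_t byte_offset = block_id * MPerBlock * NPerBlock * sizeof(AccType);
    return static_cast<char*>(p_workspace) + byte_offset;
}

int main()
{
    static char workspace[4 * MPerBlock * NPerBlock * sizeof(AccType)];
    for(std::size_t b = 0; b < 4; ++b)
        std::printf("block %zu -> byte offset %zu\n",
                    b,
                    static_cast<std::size_t>(
                        static_cast<char*>(block_workspace(workspace, b)) - workspace));
}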
@@ -91,9 +91,6 @@ __global__ void
     if(work_scheduler.tile_id_ >= tile_count)
         return;
 
-    if(get_thread_global_1d_id() < work_scheduler.GetFlagCount(k_batch))
-        p_flags[get_thread_global_1d_id()] = 0;
-
     index_t group_id = 0;
     index_t offset   = 0;
@@ -153,16 +150,20 @@ __global__ void
     } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
 
+    const index_t output_tile_idx =
+        __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
+    const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
+
     // if (changed group_id || next [M,N] tile)
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
-        gridwise_gemm.StorePartials(p_workspace);
+        void* __restrict__ p_block_workspace = reinterpret_cast<void* __restrict__>(
+            reinterpret_cast<char*>(p_workspace) + blockIdx.x * GridwiseGemm::GetMPerBlock() *
+                                                       GridwiseGemm::GetNPerBlock() *
+                                                       sizeof(typename GridwiseGemm::AccType));
+        gridwise_gemm.StorePartials(p_block_workspace);
     }
 
-    const index_t output_tile_idx =
-        __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
-    const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
-
     work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
 
     // The workgroup which processed first K tile accumulates results and stores to GMEM
@@ -173,10 +174,13 @@ __global__ void
         // Accumulate partial results. We can have different # of workgroups to reduce, thus we
         // read actual flag value.
-        const index_t flag_v = __builtin_amdgcn_readfirstlane(
+        const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
             work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-        gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
+        // Accumulate only when there is at least two workgroups processing splitk data-tiles
+        // across same MN-output tile.
+        if(flag_v > 1)
+            gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
 
         // Signal waiting blocks that they can start use their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
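A side note on the flag_v > 1 guard above: GetFlagValue reports how many workgroups have flagged completion for this [M,N] output tile, so a reduction from the workspace is only needed when at least one other workgroup stored partials. A host-side model of that protocol, purely illustrative and not the device code, might look like:

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int k_batch = 4;                      // workgroups cooperating on one [M,N] tile
    std::atomic<std::uint32_t> flag{0};             // models one entry of p_flags
    std::vector<float> partials(k_batch - 1, 1.0f); // partials stored by non-first workgroups
    float acc = 1.0f;                               // result the first workgroup keeps in registers

    for(int wg = 0; wg < k_batch; ++wg)             // every workgroup signals completion
        flag.fetch_add(1, std::memory_order_release);

    const std::uint32_t flag_v = flag.load(std::memory_order_acquire);
    if(flag_v > 1)                                  // reduction needed only with >1 contributor
        for(std::uint32_t i = 0; i + 1 < flag_v; ++i)
            acc += partials[i];

    std::printf("flag_v = %u, acc = %f\n", flag_v, acc);
}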
@@ -192,7 +196,6 @@ __global__ void
         static_for<0, NumDTensor, 1>{}([&](auto i) {
             using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-            // D pointer
             p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });
@@ -206,7 +209,7 @@ __global__ void
             cde_element_op,
             b2c_tile_map);
     }
-    else
+    else if(work_scheduler.HasTile())
     {
         // TODO: double buffering in order to not wait for this.
         work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -647,10 +650,18 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             const auto a_grid_desc_kbatch_ak0_m_ak1 =
                 GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                    arg.gemm_kernel_args_[0].M,
-                    arg.gemm_kernel_args_[0].K,
-                    arg.gemm_kernel_args_[0].StrideA,
-                    arg.K_BATCH);
+                    gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);
+
+            const auto b_grid_desc_kbatch_bk0_n_bk1 =
+                GridwiseGemm::MakeBGridDescriptor_KBatch_BK0_N_BK1(
+                    gemm_arg.K, gemm_arg.N, gemm_arg.StrideB, arg.K_BATCH);
+
+            std::cout << "group id: " << i
+                      << ", kbatch: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0)
+                      << ", AK0: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)
+                      << ", AK1: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)
+                      << ", BK0: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I1)
+                      << ", BK1: " << b_grid_desc_kbatch_bk0_n_bk1.GetLength(I3) << std::endl;
 
             bool not_all_have_main_k_block_loop_same =
                 all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
@@ -736,16 +747,36 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             // configuration the first is smaller than the latter. Launching too many workgroups
             // mean some of them will have to iterate through all gemm problem descriptors just to
             // find out they have nothing to do which is of course waste of GPU cycles.
+            const index_t grid_size       = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
+            const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
+
+            std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize(
+                sizeof(typename GridwiseGemm::AccType), grid_size);
+            void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
+                            Block2ETileMapKSplit::GetAccWorkspaceSize(
+                                sizeof(typename GridwiseGemm::AccType), grid_size);
+            std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+
             if(stream_config.log_level_ > 0)
             {
-                const index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
-                const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
                 std::cout << "tile_count: " << arg.tile_count_
-                          << ", tiles_per_block: " << tiles_per_block << std::endl;
+                          << ", tiles_per_block: " << tiles_per_block
+                          << ", grid_size: " << grid_size << ", flag_count: " << flag_count
+                          << ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace
+                          << ", acc_workspace_size_bytes: " << acc_workspace_size_bytes
+                          << std::endl;
             }
 
-            return launch_and_time_kernel(
+            auto preprocess = [&]() {
+                hip_check_error(hipMemsetAsync(
+                    p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
+            };
+
+            return launch_and_time_kernel_with_preprocess(
+                // return launch_and_time_kernel(
                 stream_config,
+                preprocess,
                 kernel,
                 dim3(ck::math::min(arg.tile_count_, max_occupancy_grid_size)),
                 dim3(BlockSize),
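The preprocess lambda above replaces the in-kernel flag zeroing removed at the top of this commit: the flag buffer is cleared with hipMemsetAsync on the same stream immediately before the kernel launch, so the launch is ordered after the memset without extra synchronization. A standalone sketch of that pattern, independent of ck's launch_and_time_kernel_with_preprocess helper and with a placeholder kernel and sizes, could be:

#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#define HIP_CHECK(cmd)                                                         \
    do                                                                         \
    {                                                                          \
        hipError_t e = (cmd);                                                  \
        if(e != hipSuccess)                                                    \
        {                                                                      \
            std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(e));     \
            std::exit(EXIT_FAILURE);                                           \
        }                                                                      \
    } while(0)

// Stand-in for the real grouped-GEMM kernel: only checks that its flags arrive zeroed.
__global__ void consume_flags(const uint32_t* p_flags, uint32_t* p_nonzero, size_t flag_count)
{
    const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < flag_count && p_flags[tid] != 0)
        atomicAdd(p_nonzero, 1u);
}

int main()
{
    constexpr size_t flag_count = 64; // placeholder for grid_size * tiles_per_block / k_batch
    uint32_t* p_flags   = nullptr;
    uint32_t* p_nonzero = nullptr;
    hipStream_t stream  = nullptr;

    HIP_CHECK(hipStreamCreate(&stream));
    HIP_CHECK(hipMalloc(&p_flags, flag_count * sizeof(uint32_t)));
    HIP_CHECK(hipMalloc(&p_nonzero, sizeof(uint32_t)));

    // "Preprocess" step: reset the workgroup-cooperation flags on the kernel's stream.
    HIP_CHECK(hipMemsetAsync(p_flags, 0, flag_count * sizeof(uint32_t), stream));
    HIP_CHECK(hipMemsetAsync(p_nonzero, 0, sizeof(uint32_t), stream));

    consume_flags<<<dim3(1), dim3(64), 0, stream>>>(p_flags, p_nonzero, flag_count);

    uint32_t nonzero = 0;
    HIP_CHECK(hipMemcpyAsync(&nonzero, p_nonzero, sizeof(uint32_t), hipMemcpyDeviceToHost, stream));
    HIP_CHECK(hipStreamSynchronize(stream));
    std::printf("non-zero flags seen by the kernel: %u\n", nonzero);

    HIP_CHECK(hipFree(p_flags));
    HIP_CHECK(hipFree(p_nonzero));
    HIP_CHECK(hipStreamDestroy(stream));
}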
@@ -768,6 +799,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
+
         if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
             arg.skipped_group_count_) != arg.group_count_)
         {
@@ -783,6 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
             const auto& gemm_arg = arg.gemm_kernel_args_[i];
+
             bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg.M,
                                                                gemm_arg.N,
                                                                gemm_arg.K,
...