Commit 9177a207 authored by Adam Osewski

Fix joining of kbatch tiles.

parent eaa68635
@@ -43,8 +43,6 @@ namespace device {
/// @tparam FloatC Input tensor C elements' data type.
/// @tparam Block2ETileMapKSplit The structure providing mapping between workgroup ids,
/// the data tiles to process and the output tiles.
/// @tparam HasMainKBlockLoop Flag indicating whether all GEMM problem configurations
/// need to loop over tiles in K dimension.
///
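///
/// @note The only tile-map member exercised below is AdvanceTileKIdx(step),
///       which shifts the workgroup's current K-tile index; e.g.
///       b2c_tile_map.AdvanceTileKIdx(k_tiles - 1) moves to the last k-tile
///       this workgroup has just processed.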
template <typename GridwiseGemm,
typename GemmDesc,
@@ -55,8 +53,7 @@ template <typename GridwiseGemm,
typename Block2ETileMapKSplit,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool HasMainKBlockLoop>
typename CDEElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -147,20 +144,20 @@ __global__ void
// k_tiles);
// }
// just accumulate results in registers!
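// (each workgroup walks all of its k_tiles inside a single RunGEMM call,
// keeping the running sum in registers, and spills the accumulated partials
// to the workspace only once, in StorePartials further below)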
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map,
results_buffer,
k_tiles);
GridwiseGemm::template RunGEMM(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map,
results_buffer,
k_tiles);
// Move to the last processed k-tile
b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
@@ -175,6 +172,7 @@ __global__ void
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
#if 1
__builtin_amdgcn_sched_barrier(0);
// make sure all writes to gmem have finished.
__builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
// __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
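// Decoding the immediates (assuming a gfx9-class s_waitcnt layout: vmcnt in
// bits [3:0] and [15:14], expcnt in [6:4], lgkmcnt in [11:8]): 0x0f70 keeps
// expcnt and lgkmcnt at their "no wait" maxima and blocks only until
// vmcnt == 0, i.e. until all outstanding vector-memory (gmem) accesses have
// retired; 0x0070 would additionally wait for lgkmcnt == 0 (LDS/scalar
// traffic). The sched_barrier(0) above merely forbids the compiler from
// moving instructions across this point.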
@@ -510,73 +508,18 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void UpdateOccupancy()
{
bool all_have_main_k_block_loop;
{
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA,
K_BATCH);
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
}
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
if(all_have_main_k_block_loop)
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
else
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
false>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
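// A minimal sketch of how the cached occupancy is commonly consumed when
// sizing a persistent-workgroup grid (hipGetDeviceProperties and
// multiProcessorCount are real HIP API; tying them to occupancy_num_blocks_
// this way is an assumption, not this file's verbatim logic):
//
//     hipDeviceProp_t props;
//     hip_check_error(hipGetDeviceProperties(&props, 0 /* device id */));
//     const int grid_size = props.multiProcessorCount * occupancy_num_blocks_;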
// private:
@@ -631,8 +574,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
[[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config);
CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr)
{
@@ -650,18 +592,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str());
}
float ave_time = 0;
if(all_have_main_k_block_loop)
{
ave_time =
DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
else
{
ave_time =
DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
float ave_time = DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
return ave_time;
}
@@ -708,22 +639,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
}
private:
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
void CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{
bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
{
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) / arg.K_BATCH);
}
bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
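// All groups have to agree on whether split-K is active (K_BATCH > 1):
// joining the per-k-batch partial tiles is a single decision made for the
// whole grouped launch, so a mixed setting is rejected below.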
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
@@ -751,24 +669,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str());
}
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
kbatch);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same)
{
std::ostringstream err;
@@ -779,10 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str());
}
}
return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
}
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
@@ -797,8 +697,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
HasMainKBlockLoop>;
CDEElementwiseOperation>;
return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}