Commit 9177a207 authored by Adam Osewski

Fix joining kbatch-tiles.

parent eaa68635
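
This change removes the HasMainKBlockLoop template parameter from the grouped split-K GEMM kernel and the surrounding device op. The kernel is now instantiated once, with the decision whether to loop over K-tiles presumably taken at runtime inside GridwiseGemm::RunGEMM; the duplicated true/false kernel instantiations in UpdateOccupancy, the two-way template branch in DispatchKernel, and the per-group consistency check in CheckArgument all go away. In addition, a __builtin_amdgcn_sched_barrier(0) is inserted ahead of the s_waitcnt fence that follows StorePartials.

To illustrate the pattern being removed, here is a hedged sketch (names invented; this is not the CK code) of compile-time bool dispatch, which forces one kernel instantiation per runtime value of the flag:

    // Hypothetical sketch: with a compile-time bool, every runtime value
    // needs its own kernel instantiation plus a host-side branch.
    #include <hip/hip_runtime.h>

    template <bool HasMainKBlockLoop>
    __global__ void gemm_like_kernel(float* out)
    {
        if constexpr(HasMainKBlockLoop)
        {
            // the main loop over K tiles would live here
        }
        out[threadIdx.x] = HasMainKBlockLoop ? 1.0f : 0.0f;
    }

    void dispatch(bool has_main_k_block_loop, float* out)
    {
        if(has_main_k_block_loop)
            gemm_like_kernel<true><<<dim3(1), dim3(64), 0, 0>>>(out);  // one instantiation...
        else
            gemm_like_kernel<false><<<dim3(1), dim3(64), 0, 0>>>(out); // ...and a duplicate
    }

Folding the flag into the kernel body collapses both branches into a single instantiation, which is what the hunks below do.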
@@ -43,8 +43,6 @@ namespace device {
 /// @tparam FloatC                Input tensor C elements' data type.
 /// @tparam Block2ETileMapKSplit  The structure providing mapping between workgroup ids,
 ///                               the data tiles to process and the output tiles.
-/// @tparam HasMainKBlockLoop     Flag indicating whether all GEMM problem configurations
-///                               need to loop over tiles in K dimension.
 ///
 template <typename GridwiseGemm,
           typename GemmDesc,
@@ -55,8 +53,7 @@ template <typename GridwiseGemm,
           typename Block2ETileMapKSplit,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          bool HasMainKBlockLoop>
+          typename CDEElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -147,20 +144,20 @@ __global__ void
     //                            k_tiles);
     // }
     // just accumulate results in registers!
-    GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+    GridwiseGemm::template RunGEMM(p_a_grid,
                                    p_b_grid,
                                    static_cast<void*>(p_shared),
                                    a_element_op,
                                    b_element_op,
                                    M,
                                    N,
                                    K,
                                    StrideA,
                                    StrideB,
                                    k_batch,
                                    b2c_tile_map,
                                    results_buffer,
                                    k_tiles);
     // Move to the last processed k-tile
     b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
@@ -175,6 +172,7 @@ __global__ void
     GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
 #if 1
+    __builtin_amdgcn_sched_barrier(0);
     // make sure all writes to gmem have finished.
     __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
     // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
@@ -510,73 +508,18 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     void UpdateOccupancy()
     {
-        bool all_have_main_k_block_loop;
-        {
-            const auto a_grid_desc_ak0_m_ak1 =
-                GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
-                                                            gemm_kernel_args_[0].K,
-                                                            gemm_kernel_args_[0].StrideA,
-                                                            K_BATCH);
-
-            all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                K_BATCH);
-        }
-
-        for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
-        {
-            const auto& gemm_arg = gemm_kernel_args_[i];
-            auto kbatch          = K_BATCH;
-
-            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
-
-            bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor
-                GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
-                                                         a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                                                         K_BATCH);
-
-            if(not_all_have_main_k_block_loop_same)
-            {
-                std::ostringstream err;
-                err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                    << ":" << __LINE__ << ", in function: " << __func__;
-                throw std::runtime_error(err.str());
-            }
-        }
-
-        if(all_have_main_k_block_loop)
-        {
-            const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                                  KernelArguments,
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  EDataType,
-                                                                  DsDataType,
-                                                                  Block2ETileMapKSplit,
-                                                                  AElementwiseOperation,
-                                                                  BElementwiseOperation,
-                                                                  CDEElementwiseOperation,
-                                                                  true>;
-            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
-                &occupancy_num_blocks_, kernel, BlockSize, 0));
-        }
-        else
-        {
-            const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                                  KernelArguments,
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  EDataType,
-                                                                  DsDataType,
-                                                                  Block2ETileMapKSplit,
-                                                                  AElementwiseOperation,
-                                                                  BElementwiseOperation,
-                                                                  CDEElementwiseOperation,
-                                                                  false>;
-            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
-                &occupancy_num_blocks_, kernel, BlockSize, 0));
-        }
+        const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
+                                                              KernelArguments,
+                                                              ADataType,
+                                                              BDataType,
+                                                              EDataType,
+                                                              DsDataType,
+                                                              Block2ETileMapKSplit,
+                                                              AElementwiseOperation,
+                                                              BElementwiseOperation,
+                                                              CDEElementwiseOperation>;
+        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+            &occupancy_num_blocks_, kernel, BlockSize, 0));
     }

     // private:
@@ -631,8 +574,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                      void* dev_gemm_workspace,
                      const StreamConfig& stream_config = StreamConfig{})
     {
-        [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
-            CheckArgument(arg, stream_config);
+        CheckArgument(arg, stream_config);

         if(dev_gemm_args == nullptr)
         {
@@ -650,18 +592,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             throw std::runtime_error(err.str());
         }

-        float ave_time = 0;
-
-        if(all_have_main_k_block_loop)
-        {
-            ave_time =
-                DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
-        }
-        else
-        {
-            ave_time =
-                DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
-        }
+        float ave_time = DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config);

         return ave_time;
     }
@@ -708,22 +639,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     }

     private:
-    auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
+    void CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
     {
-        bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
-        {
-            const auto a_grid_desc_ak0_m_ak1 =
-                GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
-                                                            arg.gemm_kernel_args_[0].K,
-                                                            arg.gemm_kernel_args_[0].StrideA,
-                                                            arg.K_BATCH);
-
-            all_have_kbatch_gt_one     = arg.K_BATCH > 1;
-            all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_ak0_m_ak1.GetLength(I0) *
-                a_grid_desc_ak0_m_ak1.GetLength(I2) / arg.K_BATCH);
-        }
+        bool all_have_kbatch_gt_one = arg.K_BATCH > 1;

         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
@@ -751,24 +669,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }

-            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
-
-            bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor
-                GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
-                                                         a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                                                         kbatch);
-
             bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

-            if(not_all_have_main_k_block_loop_same)
-            {
-                std::ostringstream err;
-                err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                    << ":" << __LINE__ << ", in function: " << __func__;
-                throw std::runtime_error(err.str());
-            }
-
             if(not_all_have_kbatch_value_same)
             {
                 std::ostringstream err;
@@ -779,10 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }
         }
-
-        return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
     }

-    template <bool HasMainKBlockLoop>
     float DispatchKernel(const Argument& arg,
                          const void* dev_gemm_args,
                          void* dev_gemm_workspace,
@@ -797,8 +697,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                                               Block2ETileMapKSplit,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation,
-                                                              HasMainKBlockLoop>;
+                                                              CDEElementwiseOperation>;

         return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
     }
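
A note on the barrier added in the @@ -175,6 +172,7 @@ hunk above: __builtin_amdgcn_sched_barrier(0) (mask 0, meaning no instructions may be scheduled across it) pins the compiler's instruction scheduler at that point, while the existing __builtin_amdgcn_s_waitcnt(0x0f70) encodes s_waitcnt vmcnt(0) in the gfx9-style counter layout, blocking until all outstanding vector-memory writes have retired. Together they make the partials written by StorePartials visible in global memory before the k-batch tiles are joined; without the scheduling barrier the compiler would be free to move memory operations past the hand-placed waitcnt.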
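For reference, the occupancy query that UpdateOccupancy now performs once, rather than once per HasMainKBlockLoop value, follows the standard HIP pattern. A minimal standalone sketch, with a trivial placeholder kernel standing in for kernel_grouped_gemm_xdl_splitk_v2 and an invented block size:

    #include <hip/hip_runtime.h>
    #include <cstdio>

    // Placeholder kernel; the real code queries kernel_grouped_gemm_xdl_splitk_v2.
    __global__ void placeholder_kernel(float* p) { p[threadIdx.x] = 0.0f; }

    int main()
    {
        int num_blocks       = 0;   // analogous to occupancy_num_blocks_
        const int block_size = 256; // analogous to BlockSize in the device op

        // How many blocks of this kernel can be resident per CU at this block
        // size? The last argument is dynamic LDS bytes per block (0 here).
        hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks, placeholder_kernel, block_size, 0);
        if(err != hipSuccess)
        {
            std::printf("occupancy query failed: %s\n", hipGetErrorString(err));
            return 1;
        }
        std::printf("max active blocks per CU: %d\n", num_blocks);
        return 0;
    }

Because the kernel type no longer depends on a runtime-derived bool, a single query covers every GEMM in the group.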