"...resnet50_tensorflow.git" did not exist on "8fd10a5728f542c3d9b267e7ca0e7ad39f5b4293"
Commit 9177a207 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix joining kbatch-tiles.

parent eaa68635
...@@ -43,8 +43,6 @@ namespace device { ...@@ -43,8 +43,6 @@ namespace device {
/// @tparam FloatC Input tensor C elements' data type. /// @tparam FloatC Input tensor C elements' data type.
/// @tparam Block2ETileMapKSplit The structure providing mapping between workgroup ids, /// @tparam Block2ETileMapKSplit The structure providing mapping between workgroup ids,
/// the data tiles to process and the output tiles. /// the data tiles to process and the output tiles.
/// @tparam HasMainKBlockLoop Flag indicating whether all GEMM problem configurations
/// need to loop over tiles in K dimension.
/// ///
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename GemmDesc, typename GemmDesc,
...@@ -55,8 +53,7 @@ template <typename GridwiseGemm, ...@@ -55,8 +53,7 @@ template <typename GridwiseGemm,
typename Block2ETileMapKSplit, typename Block2ETileMapKSplit,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation>
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -147,7 +144,7 @@ __global__ void ...@@ -147,7 +144,7 @@ __global__ void
// k_tiles); // k_tiles);
// } // }
// just accumulate results in registers! // just accumulate results in registers!
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template RunGEMM(p_a_grid,
p_b_grid, p_b_grid,
static_cast<void*>(p_shared), static_cast<void*>(p_shared),
a_element_op, a_element_op,
...@@ -175,6 +172,7 @@ __global__ void ...@@ -175,6 +172,7 @@ __global__ void
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer); GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
#if 1 #if 1
__builtin_amdgcn_sched_barrier(0);
// make sure all writes to gmem has finished. // make sure all writes to gmem has finished.
__builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0) __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
// __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0) // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
...@@ -509,59 +507,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -509,59 +507,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
} }
void UpdateOccupancy() void UpdateOccupancy()
{
bool all_have_main_k_block_loop;
{
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA,
K_BATCH);
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
}
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
if(all_have_main_k_block_loop)
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
else
{ {
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm, const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments, KernelArguments,
...@@ -572,12 +517,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -572,12 +517,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
Block2ETileMapKSplit, Block2ETileMapKSplit,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation>;
false>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0)); &occupancy_num_blocks_, kernel, BlockSize, 0));
} }
}
// private: // private:
index_t K_BATCH; index_t K_BATCH;
...@@ -631,7 +574,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -631,7 +574,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* dev_gemm_workspace, void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{}) const StreamConfig& stream_config = StreamConfig{})
{ {
[[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config); CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr) if(dev_gemm_args == nullptr)
...@@ -650,18 +592,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -650,18 +592,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
float ave_time = 0; float ave_time = DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
if(all_have_main_k_block_loop)
{
ave_time =
DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
else
{
ave_time =
DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
return ave_time; return ave_time;
} }
...@@ -708,22 +639,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -708,22 +639,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
} }
private: private:
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const void CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{ {
bool all_have_kbatch_gt_one, all_have_main_k_block_loop; bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
{
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2 / kbatch);
}
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
...@@ -751,24 +669,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -751,24 +669,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
kbatch);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same) if(not_all_have_kbatch_value_same)
{ {
std::ostringstream err; std::ostringstream err;
...@@ -779,10 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -779,10 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
} }
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg, float DispatchKernel(const Argument& arg,
const void* dev_gemm_args, const void* dev_gemm_args,
void* dev_gemm_workspace, void* dev_gemm_workspace,
...@@ -797,8 +697,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -797,8 +697,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
Block2ETileMapKSplit, Block2ETileMapKSplit,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation>;
HasMainKBlockLoop>;
return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config); return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment