Commit a6ef5c39 authored by Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 9b3c4ac4 1274861a
...@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
......
...@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_d_grid_descs_m_n_.reserve(group_count_); elementwise_d_grid_descs_m_n_.reserve(group_count_);
ds_grid_pointer_.reserve(group_count_); ds_grid_pointer_.reserve(group_count_);
group_grid_size_.reserve(group_count_); group_grid_size_.reserve(group_count_);
e_ptrs_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_descs.size(); ++i) for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{ {
...@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t block_end = grid_size_ + grid_size_grp; const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
group_grid_size_[i] = grid_size_grp; group_grid_size_.push_back(grid_size_grp);
// block-to-e-tile map // block-to-e-tile map
auto grouped_block_2_ctile_map = auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
...@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
ds_grid_pointer_.push_back(p_ds_grid); ds_grid_pointer_.push_back(p_ds_grid);
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_.push_back(p_Es[i]);
} }
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_ = p_Es;
} }
/** /**
...@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end; gemm_kernel_args_[i].block_end_ = block_end;
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
index_t tiles = (block_end - block_start) / K_BATCH; index_t tiles = (block_end - block_start) / K_BATCH;
std::cout << "block_start: " << block_start << "\n" std::cout << "block_start: " << block_start << "\n"
...@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg.karg_.p_c_grid = p_workspace + offset; arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
offset += tiles * MPerBlock * NPerBlock; offset += tiles * MPerBlock * NPerBlock;
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "block_start: " << arg.block_start_ << "\n" std::cout << "block_start: " << arg.block_start_ << "\n"
<< "block_end: " << arg.block_end_ << "\n" << "block_end: " << arg.block_end_ << "\n"
...@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(dev_gemm_args), cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_, arg.gemm_kernel_args_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
PassThrough{}); PassThrough{});
// Elementwise kernels // Elementwise kernels
for(int i = 0; i < arg.group_count_; ++i) for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
time += launch_and_time_kernel( time += launch_and_time_kernel(
stream_config, stream_config,
...@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "The group count is not equal to sum of skipped groups " std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!" "and kernel args size!"
...@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
if(not group_arg_valid) if(not group_arg_valid)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "[" << __func__ << "] group id: " << i std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl; << " has invalid GridwiseGemm settings!" << std::endl;
......
...@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>( GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K))) M, N, K)))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< K << "] are not supported by current template parameters!" << K << "] are not supported by current template parameters!"
......
...@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{" std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
......
...@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "The group count is not equal to sum of skipped groups " std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!" "and kernel args size!"
...@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool group_arg_valid = GridwiseGemm::CheckValidity(a); bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid) if(not group_arg_valid)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "[" << __func__ << "] group id: " << i std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl; << " has invalid GridwiseGemm settings!" << std::endl;
......
...@@ -935,7 +935,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -935,7 +935,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -952,7 +952,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -952,7 +952,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -971,7 +971,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -971,7 +971,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock; auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0)) if(!(karg.K % K_t == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__ << karg.K << " " << __FILE__ << ":" << __LINE__
...@@ -995,7 +995,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -995,7 +995,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -1009,7 +1009,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1009,7 +1009,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -1024,7 +1024,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1024,7 +1024,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -1038,7 +1038,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1038,7 +1038,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -1053,7 +1053,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1053,7 +1053,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of " << ") value is not a multiple of "
...@@ -1069,7 +1069,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1069,7 +1069,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of " << ") value is not a multiple of "
...@@ -1084,7 +1084,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1084,7 +1084,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl; << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
......
...@@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto K_t = karg.KBatch * KPerBlock; auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0)) if(!(karg.K % K_t == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__ << karg.K << " " << __FILE__ << ":" << __LINE__
...@@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of " << ") value is not a multiple of "
...@@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of " << ") value is not a multiple of "
......
...@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
...@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto K_t = karg.k_batch * K0PerBlock * K1; auto K_t = karg.k_batch * K0PerBlock * K1;
if(!(karg.K % K_t == 0)) if(!(karg.K % K_t == 0))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__ << karg.K << " " << __FILE__ << ":" << __LINE__
...@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector (" << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg K (" << karg.K std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector (" << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg N (" << karg.N std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of " << ") value is not a multiple of "
...@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "Arg M (" << karg.M std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of " << ") value is not a multiple of "
...@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto num_k_loop = karg.K0Padded / K0PerBlock; const auto num_k_loop = karg.K0Padded / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "The number of k loops (" << num_k_loop std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline." << ") value is not supported by GridwiseGemm Pipeline."
......
...@@ -124,7 +124,7 @@ struct EnvVar ...@@ -124,7 +124,7 @@ struct EnvVar
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") #define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
#define ENV(name) \ #define CK_ENV(name) \
ck::env::name {} ck::env::name {}
template <class EnvVar> template <class EnvVar>
......
...@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x) ...@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x)
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x) constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{ {
return __float2half(x); // return __float2half(x);
// return static_cast<fp16_hip_t>(x); return static_cast<fp16_hip_t>(x);
} }
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
......
...@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+) ...@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-) CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~) CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!) CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)
CK_TILE_BINARY_OP(+) CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-) CK_TILE_BINARY_OP(-)
......
...@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, ...@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
c_m_n_host_results.push_back( c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
......
...@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
c_m_n_host_results.push_back( c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
......
...@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, ...@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back( c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
......
...@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, ...@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
c_m_n_host_results.push_back( c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
......
...@@ -6,6 +6,12 @@ if(result EQUAL 0) ...@@ -6,6 +6,12 @@ if(result EQUAL 0)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif() endif()
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_F16_F16_F16_LargeK =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_BF16_BF16_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
using RCR_BF16_BF16_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;
using RRR_BF16_I8_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
using RCR_BF16_I8_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
// Parameter values shared by all regular (non-LargeK) suite instantiations below.
// NOTE(review): presumably these are split-K batch counts -- the MNKPadded test
// cases forward GetParam() as the last argument of Run(); confirm in the fixture.
// Defined before the INSTANTIATE_* calls in this translation unit, which read it
// during their own static initialization.
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
// F16 suites, one instantiation per fixture alias, each swept over KBATCH.
// The first macro argument is the instantiation-name prefix shown in test names.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
RRR_F16_F16_F16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
RCR_F16_F16_F16,
testing::ValuesIn(KBATCH));
// BF16 suites, swept over KBATCH.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
RRR_BF16_BF16_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
RCR_BF16_BF16_BF16,
testing::ValuesIn(KBATCH));
// Mixed BF16 / I8 suites, swept over KBATCH.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
RRR_BF16_I8_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
RCR_BF16_I8_BF16,
testing::ValuesIn(KBATCH));
// LargeK suites use their own parameter set (32 and 64) instead of KBATCH.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc"
#include "test_grouped_gemm_two_stage_ut_cases.inc"
#pragma once
// Runs four GEMM groups that share N and K and differ only in M.
TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
{
    using Dims = std::vector<int>;

    constexpr int n_shared = 136;
    constexpr int k_shared = 280;

    const Dims m_sizes{127, 150, 188, 210};
    const auto groups = m_sizes.size();

    const Dims n_sizes(groups, n_shared);
    const Dims k_sizes(groups, k_shared);

    // Stride values for this suite: K for A, N for B, N for C.
    const Dims stride_a(groups, k_shared);
    const Dims stride_b(groups, n_shared);
    const Dims stride_c(groups, n_shared);

    // The suite parameter is forwarded as the kbatch argument of Run().
    this->Run(m_sizes, n_sizes, k_sizes, stride_a, stride_b, stride_c, this->GetParam());
}
// Runs four GEMM groups that share N and K and differ only in M.
TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
{
    using Dims = std::vector<int>;

    constexpr int n_shared = 136;
    constexpr int k_shared = 280;

    const Dims m_sizes{127, 150, 188, 210};
    const auto groups = m_sizes.size();

    const Dims n_sizes(groups, n_shared);
    const Dims k_sizes(groups, k_shared);

    // Stride values for this suite: K for A, K for B, N for C.
    const Dims stride_a(groups, k_shared);
    const Dims stride_b(groups, k_shared);
    const Dims stride_c(groups, n_shared);

    // The suite parameter is forwarded as the kbatch argument of Run().
    this->Run(m_sizes, n_sizes, k_sizes, stride_a, stride_b, stride_c, this->GetParam());
}
// Runs four GEMM groups that share N and K and differ only in M.
TEST_P(RRR_BF16_I8_BF16, MNKPadded)
{
    using Dims = std::vector<int>;

    constexpr int n_shared = 136;
    constexpr int k_shared = 280;

    const Dims m_sizes{127, 150, 188, 210};
    const auto groups = m_sizes.size();

    const Dims n_sizes(groups, n_shared);
    const Dims k_sizes(groups, k_shared);

    // Stride values for this suite: K for A, N for B, N for C.
    const Dims stride_a(groups, k_shared);
    const Dims stride_b(groups, n_shared);
    const Dims stride_c(groups, n_shared);

    // The suite parameter is forwarded as the kbatch argument of Run().
    this->Run(m_sizes, n_sizes, k_sizes, stride_a, stride_b, stride_c, this->GetParam());
}
// Runs four GEMM groups that share N and K and differ only in M.
TEST_P(RCR_BF16_I8_BF16, MNKPadded)
{
    using Dims = std::vector<int>;

    constexpr int n_shared = 136;
    constexpr int k_shared = 280;

    const Dims m_sizes{127, 150, 188, 210};
    const auto groups = m_sizes.size();

    const Dims n_sizes(groups, n_shared);
    const Dims k_sizes(groups, k_shared);

    // Stride values for this suite: K for A, K for B, N for C.
    const Dims stride_a(groups, k_shared);
    const Dims stride_b(groups, k_shared);
    const Dims stride_c(groups, n_shared);

    // The suite parameter is forwarded as the kbatch argument of Run().
    this->Run(m_sizes, n_sizes, k_sizes, stride_a, stride_b, stride_c, this->GetParam());
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp" #include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck { namespace ck {
namespace test { namespace test {
...@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
} }
}; };
/// Value-parameterized gtest fixture that hands a grouped GEMM problem to
/// ck::profiler::profile_grouped_gemm_two_stage_impl and expects it to report success.
///
/// @tparam Tuple  std::tuple of <ALayout, BLayout, ELayout, ADataType, BDataType, EDataType>.
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
    protected:
    // Tuple slots 0-2 carry the layouts, slots 3-5 the element types.
    using ALayout   = std::tuple_element_t<0, Tuple>;
    using BLayout   = std::tuple_element_t<1, Tuple>;
    using ELayout   = std::tuple_element_t<2, Tuple>;
    using ADataType = std::tuple_element_t<3, Tuple>;
    using BDataType = std::tuple_element_t<4, Tuple>;
    using EDataType = std::tuple_element_t<5, Tuple>;

    public:
    // Fixed settings forwarded to the profiler on every Run().
    static constexpr bool verify_     = true;
    static constexpr int init_method_ = 1;     // decimal value initialization
    static constexpr bool log_        = false;
    static constexpr bool bench_      = false; // measure kernel performance

    void SetUp() override {}

    /// Forwards one grouped GEMM problem to the profiler and records a (non-fatal) gtest
    /// failure when the profiler returns false.
    ///
    /// @param Ms, Ns, Ks                     Per-group problem sizes.
    /// @param StrideAs, StrideBs, StrideCs   Per-group strides, forwarded unchanged.
    /// @param kbatch                         Forwarded unchanged; defaults to 1.
    /// @param n_warmup, n_iter               Forwarded unchanged; defaults are 1 and 10.
    ///                                       NOTE(review): presumably warm-up / timed iteration
    ///                                       counts - confirm against the profiler.
    void Run(const std::vector<int>& Ms,
             const std::vector<int>& Ns,
             const std::vector<int>& Ks,
             const std::vector<int>& StrideAs,
             const std::vector<int>& StrideBs,
             const std::vector<int>& StrideCs,
             int kbatch   = 1,
             int n_warmup = 1,
             int n_iter   = 10)
    {
        // NOTE(review): the fourth template argument (float) is presumably the accumulation
        // type - confirm against the profiler's template parameter list.
        const bool succeeded =
            ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
                                                              BDataType,
                                                              EDataType,
                                                              float,
                                                              ALayout,
                                                              BLayout,
                                                              ELayout>(
                verify_, init_method_, log_, bench_,
                Ms, Ns, Ks,
                StrideAs, StrideBs, StrideCs,
                kbatch, n_warmup, n_iter);

        EXPECT_TRUE(succeeded);
    }
};
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment