Commit ee1c88b1 authored by Jing Zhang's avatar Jing Zhang
Browse files

add setkbatch

parent 41a1466a
...@@ -57,7 +57,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F ...@@ -57,7 +57,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; //< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
struct ProblemSize final struct ProblemSize final
...@@ -77,6 +78,7 @@ struct ExecutionConfig final ...@@ -77,6 +78,7 @@ struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
int k_batch = 1;
bool time_kernel = false; bool time_kernel = false;
}; };
...@@ -238,6 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -238,6 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
} }
gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer()); gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
...@@ -293,8 +296,7 @@ int main(int argc, char* argv[]) ...@@ -293,8 +296,7 @@ int main(int argc, char* argv[])
problem_size.group_count = 16; problem_size.group_count = 16;
problem_size.Ms = { problem_size.Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 118, 0, 1, 148};
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
...@@ -306,17 +308,19 @@ int main(int argc, char* argv[]) ...@@ -306,17 +308,19 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]);
} }
if(argc == 4) if(argc == 5)
{ {
config.do_verification = std::stoi(argv[1]); config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]); config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]); config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (> 0)\n");
exit(0); exit(0);
} }
......
...@@ -54,6 +54,7 @@ struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout, ...@@ -54,6 +54,7 @@ struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
CElementwiseOperation> CElementwiseOperation>
{ {
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
}; };
} // namespace device } // namespace device
......
...@@ -36,6 +36,7 @@ template <typename GridwiseGemm, ...@@ -36,6 +36,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -103,25 +104,29 @@ __global__ void ...@@ -103,25 +104,29 @@ __global__ void
const auto block_2_etile_map = const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
GridwiseGemm:: GridwiseGemm::template Run<HasMainKBlockLoop,
template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>( EGlobalMemoryDataOperation,
gemm_desc_ptr[group_id].p_a_grid, GemmSpec,
gemm_desc_ptr[group_id].p_b_grid, ALayout,
p_ds_grid_, BLayout,
gemm_desc_ptr[group_id].p_e_grid, DsLayout,
p_shared, ELayout>(gemm_desc_ptr[group_id].p_a_grid,
a_element_op, gemm_desc_ptr[group_id].p_b_grid,
b_element_op, p_ds_grid_,
c_element_op, gemm_desc_ptr[group_id].p_e_grid,
M, p_shared,
N, a_element_op,
K, b_element_op,
StrideA, c_element_op,
StrideB, M,
StrideDs, N,
StrideE, K,
KBatch, StrideA,
block_2_etile_map); StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
id_off += grid_size_grp; id_off += grid_size_grp;
} }
...@@ -195,8 +200,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -195,8 +200,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static const index_t k_batch = 2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -211,7 +214,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -211,7 +214,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::AtomicAdd,
NumPrefetch, // NumGemmKPrefetchStage NumPrefetch, // NumGemmKPrefetchStage
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -406,6 +408,33 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -406,6 +408,33 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
void UpdateKBatch(index_t k_batch)
{
k_batch_ = k_batch;
if(k_batch_ < 1)
{
throw std::runtime_error("wrong! k_batch must be > 0");
}
const index_t AverM = sum_of_m / group_count_;
const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
const index_t N = gemm_desc_kernel_arg_[0].N_;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
AverM, N, StrideE);
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
grid_size_ = grid_size_grp_ * group_count_;
}
Argument(std::vector<const void*>& p_As, Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds, std::vector<std::array<const void*, NumDTensor>>& p_Ds,
...@@ -418,6 +447,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -418,6 +447,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{ {
grid_size_ = 0; grid_size_ = 0;
k_batch_ = 1;
grouped_gemm_kernel_args_dev = nullptr; grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
...@@ -497,19 +528,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -497,19 +528,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AverM, N, StrideE); AverM, N, StrideE);
// block-to-e-tile map // block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch}; const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
// std::cout << "group_id: " << group_id << " grid_size_grp: " << grid_size_grp grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
//<< std::endl;
if(group_id * grid_size_grp != grid_size_) if(group_id * grid_size_grp_ != grid_size_)
{ {
throw std::runtime_error("wrong! grid_size_grp is not identical!"); throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
} }
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp_;
// check block-to-E-tile // check block-to-E-tile
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
...@@ -557,8 +585,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -557,8 +585,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const void* grouped_gemm_kernel_args_dev; const void* grouped_gemm_kernel_args_dev;
index_t grid_size_; index_t grid_size_;
index_t grid_size_grp; index_t grid_size_grp_;
index_t sum_of_m; index_t sum_of_m;
index_t k_batch_;
}; };
// Invoker // Invoker
...@@ -570,37 +600,25 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -570,37 +600,25 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{ {
bool has_main_k_block_loop = true; bool has_main_k_block_loop = true;
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
const auto KPad = const auto KPad =
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, k_batch); GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);
if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
{ {
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
} }
}
grouped_gemm_kernel_args.push_back( if(arg.grouped_gemm_kernel_args_dev == nullptr)
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_, {
arg.gemm_desc_kernel_arg_[i].b_ptr_, throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
arg.gemm_desc_kernel_arg_[i].ds_ptr_,
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
arg.gemm_desc_kernel_arg_[i].K_,
arg.gemm_desc_kernel_arg_[i].StrideA_,
arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_});
} }
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm, kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>, GroupedGemmKernelArgument<NumDTensor>,
...@@ -615,13 +633,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -615,13 +633,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>; has_main_k_block_loop_>;
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
}
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
kernel, kernel,
...@@ -630,20 +644,43 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -630,20 +644,43 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
0, 0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp, arg.grid_size_grp_,
k_batch, arg.k_batch_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.c_element_op_);
}; };
if(has_main_k_block_loop) constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
if(arg.k_batch_ > 1)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
} }
return ave_time; return ave_time;
...@@ -775,6 +812,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -775,6 +812,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return dynamic_cast<const Argument*>(p_arg)->group_count_ * return dynamic_cast<const Argument*>(p_arg)->group_count_ *
sizeof(GroupedGemmKernelArgument<NumDTensor>); sizeof(GroupedGemmKernelArgument<NumDTensor>);
} }
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
}
}; };
} // namespace device } // namespace device
......
...@@ -37,7 +37,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype ...@@ -37,7 +37,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -632,6 +631,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -632,6 +631,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_KBatch_AK0_M_AK1, typename AGridDesc_KBatch_AK0_M_AK1,
typename BGridDesc_KBatch_BK0_N_BK1, typename BGridDesc_KBatch_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -1074,6 +1074,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -1074,6 +1074,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
...@@ -1139,19 +1140,20 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -1139,19 +1140,20 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid, Run<HasMainKBlockLoop, EGlobalMemoryDataOperation>(
p_b_grid, p_a_grid,
p_ds_grid, p_b_grid,
p_e_grid, p_ds_grid,
p_shared, p_e_grid,
a_element_op, p_shared,
b_element_op, a_element_op,
cde_element_op, b_element_op,
a_grid_desc_kbatch_ak0_m_ak1, cde_element_op,
b_grid_desc_kbatch_bk0_n_bk1, a_grid_desc_kbatch_ak0_m_ak1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, b_grid_desc_kbatch_bk0_n_bk1,
e_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map); e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment