Commit ee1c88b1 authored by Jing Zhang

add SetKBatch

parent 41a1466a
@@ -57,7 +57,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct ProblemSize final
@@ -77,6 +78,7 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
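// split-K factor handed to the device op; 1 means a single K batch (no split-K)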
int k_batch = 1;
bool time_kernel = false;
};
@@ -238,6 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
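// newly added: choose the split-K factor for this run (k_batch must be > 0)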
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
@@ -293,8 +296,7 @@ int main(int argc, char* argv[])
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
problem_size.Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 118, 0, 1, 148};
for(int i = 0; i < problem_size.group_count; i++)
{
@@ -306,17 +308,19 @@ int main(int argc, char* argv[])
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (> 0)\n");
exit(0);
}
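With this change the example expects four positional arguments instead of three. As a hypothetical invocation (the binary name depends on the build target), running the example with the arguments "1 1 1 4" would verify the result, use integer initialization, time the kernel, and split K into 4 batches.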
......
@@ -54,6 +54,7 @@ struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
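// newly added hook: set the split-K factor on an argument that has already been created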
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
};
} // namespace device
......
@@ -36,6 +36,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -103,25 +104,29 @@ __global__ void
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
GridwiseGemm::template Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
block_2_etile_map);
id_off += grid_size_grp;
}
@@ -195,8 +200,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static constexpr index_t NumDTensor = DsDataType::Size();
static const index_t k_batch = 2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -211,7 +214,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::AtomicAdd,
NumPrefetch, // NumGemmKPrefetchStage
BlockSize,
MPerBlock,
@@ -406,6 +408,33 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// Argument
struct Argument : public BaseArgument
{
void UpdateKBatch(index_t k_batch)
{
k_batch_ = k_batch;
if(k_batch_ < 1)
{
throw std::runtime_error("wrong! k_batch must be > 0");
}
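// The block-to-E-tile map depends on k_batch_, so the per-group grid size and the
// total grid size are recomputed below from the first group's E descriptor
// (N and StrideE are shared across groups in the fixed-NK case).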
const index_t AverM = sum_of_m / group_count_;
const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
const index_t N = gemm_desc_kernel_arg_[0].N_;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
AverM, N, StrideE);
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
grid_size_ = grid_size_grp_ * group_count_;
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
@@ -418,6 +447,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{
grid_size_ = 0;
k_batch_ = 1;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
@@ -497,19 +528,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AverM, N, StrideE);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch};
grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
// std::cout << "group_id: " << group_id << " grid_size_grp: " << grid_size_grp
//<< std::endl;
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
if(group_id * grid_size_grp != grid_size_)
if(group_id * grid_size_grp_ != grid_size_)
{
throw std::runtime_error("wrong! grid_size_grp is not identical!");
throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
}
grid_size_ += grid_size_grp;
grid_size_ += grid_size_grp_;
// check block-to-E-tile
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
@@ -557,8 +585,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
index_t grid_size_grp;
index_t grid_size_grp_;
index_t sum_of_m;
index_t k_batch_;
};
// Invoker
@@ -570,37 +600,25 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{
bool has_main_k_block_loop = true;
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
const auto KPad =
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, k_batch);
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);
if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_,
arg.gemm_desc_kernel_arg_[i].ds_ptr_,
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
arg.gemm_desc_kernel_arg_[i].K_,
arg.gemm_desc_kernel_arg_[i].StrideA_,
arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_});
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
}
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
@@ -615,13 +633,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
}
return launch_and_time_kernel(
stream_config,
kernel,
@@ -630,20 +644,43 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp,
k_batch,
arg.grid_size_grp_,
arg.k_batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
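// With split-K (k_batch_ > 1) every K batch writes a partial sum, so the E output
// is combined with atomic adds; for k_batch_ == 1 a plain Set store is sufficient.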
if(arg.k_batch_ > 1)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
return ave_time;
@@ -775,6 +812,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return dynamic_cast<const Argument*>(p_arg)->group_count_ *
sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
}
};
} // namespace device
......
@@ -37,7 +37,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -632,6 +631,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_KBatch_AK0_M_AK1,
typename BGridDesc_KBatch_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1074,6 +1074,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
@@ -1139,19 +1140,20 @@
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};
......