Commit 6e9ef894 authored by rtmadduri
Browse files

Apply changes to the tail-number lambda; clean up constructors

parent bf73d297
...@@ -76,7 +76,8 @@ __global__ void ...@@ -76,7 +76,8 @@ __global__ void
karg, karg,
karg.a_element_op, karg.a_element_op,
karg.b_element_op, karg.b_element_op,
karg.c_element_op gemm_desc_ptr[group_id].block_2_ctile_map_); karg.c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -137,20 +138,18 @@ template <typename ALayout, ...@@ -137,20 +138,18 @@ template <typename ALayout,
// MultipleD not supported for now. // MultipleD not supported for now.
enable_if_t<is_same_v<DsLayout, ck::Tuple<>> && is_same_v<DsDataType, ck::Tuple<>>, enable_if_t<is_same_v<DsLayout, ck::Tuple<>> && is_same_v<DsDataType, ck::Tuple<>>,
bool> = false> bool> = false>
>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
struct DeviceGroupedGemmXdlSplitKCShuffle BLayout,
: public DeviceGroupedGemmSplitK<ALayout, DsLayout,
BLayout, ELayout,
DsLayout, ADataType,
ELayout, BDataType,
ADataType, DsDataType,
BDataType, EDataType,
DsDataType, AElementwiseOperation,
EDataType, BElementwiseOperation,
AElementwiseOperation, CDEElementwiseOperation>
BElementwiseOperation,
CDEElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -221,7 +220,7 @@ template <typename ALayout, ...@@ -221,7 +220,7 @@ template <typename ALayout,
GroupedGemmBlock2ETileMap block_2_ctile_map_; GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
GemmTransKernelArg() = default; // GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg, GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map, GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start, index_t block_start,
...@@ -243,11 +242,8 @@ template <typename ALayout, ...@@ -243,11 +242,8 @@ template <typename ALayout,
Argument(std::vector<const void*>& p_a_grid, Argument(std::vector<const void*>& p_a_grid,
std::vector<const void*>& p_b_grid, std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs)
AElementwiseOperation a_element_op, : Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch)
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch, a_element_op, b_element_op, cde_element_op )
{ {
// TODO: use occupancy api to calculate appropriate batch size. // TODO: use occupancy api to calculate appropriate batch size.
} }
...@@ -256,10 +252,7 @@ template <typename ALayout, ...@@ -256,10 +252,7 @@ template <typename ALayout,
std::vector<const void*>& p_b_grid, std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
index_t kbatch, index_t kbatch)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
: K_BATCH{kbatch} : K_BATCH{kbatch}
{ {
grid_size_ = 0; grid_size_ = 0;
...@@ -307,19 +300,14 @@ template <typename ALayout, ...@@ -307,19 +300,14 @@ template <typename ALayout,
KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]), KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
type_convert<const BDataType*>(p_b_grid[i]), type_convert<const BDataType*>(p_b_grid[i]),
{}, // p_ds_grid
type_convert<EDataType*>(p_c_grid[i]), type_convert<EDataType*>(p_c_grid[i]),
M, M,
N, N,
K, K,
stride_a, stride_a,
stride_b, stride_b,
{}, // StrideDs_
stride_c, stride_c,
K_BATCH, K_BATCH};
a_element_op,
b_element_op,
cde_element_op};
gemm_kernel_args_.emplace_back( gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
...@@ -341,8 +329,8 @@ template <typename ALayout, ...@@ -341,8 +329,8 @@ template <typename ALayout,
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
const auto local_b2c_tile_map = Block2ETileMap{M, N, 4}; const auto local_b2c_tile_map = Block2ETileMap{karg.M, karg.N, 4};
index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N); index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
grid_size_grp *= K_BATCH; grid_size_grp *= K_BATCH;
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
...@@ -380,16 +368,17 @@ template <typename ALayout, ...@@ -380,16 +368,17 @@ template <typename ALayout,
bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0); bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0);
const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0); const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0);
bool all_have_kbatch_gt_one = karg0.KBatch > 1;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
const auto& karg = arg.gemm_kernel_args_[i].karg_;
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
karg.Print(); karg.Print();
} }
const auto& karg = arg.gemm_kernel_args_[i].karg_;
index_t k_grain = karg.KBatch * KPerBlock; index_t k_grain = karg.KBatch * KPerBlock;
index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock; index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock;
...@@ -460,11 +449,16 @@ template <typename ALayout, ...@@ -460,11 +449,16 @@ template <typename ALayout,
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
// TODO: should be loop here through all groups // TODO: should be loop here through all groups
if(arg_.KBatch > 1) for(const auto& trans_arg : arg.gemm_kernel_args_)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, {
0, const auto& karg = trans_arg.karg_;
arg_.M * arg_.N * sizeof(CDataType), if(karg.KBatch > 1)
stream_config.stream_id_)); hipGetErrorString(
hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -480,11 +474,15 @@ template <typename ALayout, ...@@ -480,11 +474,15 @@ template <typename ALayout,
else else
{ {
// TODO: should be loop here through all groups // TODO: should be loop here through all groups
if(arg.KBatch > 1) for(const auto& trans_arg : arg.gemm_kernel_args_)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid, {
0, const auto& karg = trans_arg.karg_;
arg.M * arg.N * sizeof(CDataType), if(karg.KBatch > 1)
stream_config.stream_id_)); hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, stream_config,
...@@ -537,7 +535,8 @@ template <typename ALayout, ...@@ -537,7 +535,8 @@ template <typename ALayout,
{ {
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg true, GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::One>;
...@@ -546,313 +545,231 @@ template <typename ALayout, ...@@ -546,313 +545,231 @@ template <typename ALayout,
//// TODO: Fix below as above! //// TODO: Fix below as above!
else if(calculate_tail_number() == TailNumber::Full) else if(tail_num == TailNumber::Full)
{ {
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
else const auto kernel =
{ kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
throw_error(); GemmTransKernelArg,
} true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{ {
if(calculate_tail_number() == TailNumber::Two) if(tail_num == TailNumber::Two)
{ {
if(all_have_same_tail_number())
{ const auto kernel = kernel_grouped_gemm_xdl_splitk<
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Two>; TailNumber::Two>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{ {
if(calculate_tail_number() == TailNumber::Three) if(tail_num == TailNumber::Three)
{ {
if(all_have_same_tail_number())
{ const auto kernel = kernel_grouped_gemm_xdl_splitk<
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Three>; TailNumber::Three>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(calculate_tail_number() == TailNumber::Four) if(tail_num == TailNumber::Four)
{ {
if(all_have_same_tail_number()) const auto kernel = kernel_grouped_gemm_xdl_splitk<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GemmTransKernelArg,
GridwiseGemm, true,
true, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy,
minimum_occupancy, TailNumber::Four>;
TailNumber::Four>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{ {
if(calculate_tail_number() == TailNumber::Five) if(tail_num == TailNumber::Five)
{ {
if(all_have_same_tail_number()) const auto kernel = kernel_grouped_gemm_xdl_splitk<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GemmTransKernelArg,
GridwiseGemm, true,
true, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy,
minimum_occupancy, TailNumber::Five>;
TailNumber::Five>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{ {
if(calculate_tail_number() == TailNumber::Six) if(tail_num == TailNumber::Six)
{ {
if(all_have_same_tail_number()) const auto kernel = kernel_grouped_gemm_xdl_splitk<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GemmTransKernelArg,
GridwiseGemm, true,
true, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy,
minimum_occupancy, TailNumber::Six>;
TailNumber::Six>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{ {
if(calculate_tail_number() == TailNumber::Seven) if(tail_num == TailNumber::Seven)
{ {
if(all_have_same_tail_number()) const auto kernel = kernel_grouped_gemm_xdl_splitk<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GemmTransKernelArg,
GridwiseGemm, true,
true, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy,
minimum_occupancy, TailNumber::Seven>;
TailNumber::Seven>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
} }
} }
else else
{ {
if(calculate_tail_number() == TailNumber::One) if(tail_num == TailNumber::One)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::One>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
else if(calculate_tail_number() == TailNumber::Full) else if(tail_num == TailNumber::Full)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Full>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{ {
if(calculate_tail_number() == TailNumber::Two) if(tail_num == TailNumber::Two)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Two>; TailNumber::Two>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{ {
if(calculate_tail_number() == TailNumber::Three) if(tail_num == TailNumber::Three)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Three>; TailNumber::Three>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(calculate_tail_number() == TailNumber::Four) if(tail_num == TailNumber::Four)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Four>; TailNumber::Four>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{ {
if(calculate_tail_number() == TailNumber::Five) if(tail_num == TailNumber::Five)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Five>; TailNumber::Five>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{ {
if(calculate_tail_number() == TailNumber::Six) if(tail_num == TailNumber::Six)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Six>; TailNumber::Six>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{ {
if(calculate_tail_number() == TailNumber::Seven) if(tail_num == TailNumber::Seven)
{ {
if(all_have_same_tail_number()) const auto kernel =
{ kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
const auto kernel = GemmTransKernelArg,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
true, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::Set, minimum_occupancy,
minimum_occupancy, TailNumber::Seven>;
TailNumber::Seven>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
} }
} }
...@@ -862,77 +779,57 @@ template <typename ALayout, ...@@ -862,77 +779,57 @@ template <typename ALayout,
{ {
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
if(calculate_tail_number() == TailNumber::Odd) if(tail_num == TailNumber::Odd)
{ {
if(all_have_same_tail_number())
{ const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
else else
{ {
if(all_have_same_tail_number())
{ const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
else else
{ {
if(calculate_tail_number() == TailNumber::Odd) if(tail_num == TailNumber::Odd)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
else else
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
} }
...@@ -941,78 +838,57 @@ template <typename ALayout, ...@@ -941,78 +838,57 @@ template <typename ALayout,
{ {
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
if(calculate_tail_number() == TailNumber::Odd) if(tail_num == TailNumber::Odd)
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3< kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
else else
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3< kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
else else
{ {
if(calculate_tail_number() == TailNumber::Odd) if(tail_num == TailNumber::Odd)
{ {
if(all_have_same_tail_number()) const auto kernel =
{ kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
const auto kernel = GemmTransKernelArg,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
true, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::Set, minimum_occupancy,
minimum_occupancy, TailNumber::Odd>;
TailNumber::Odd>; Run(kernel);
Run(kernel);
}
else
{
throw_error();
}
} }
else else
{ {
if(all_have_same_tail_number())
{ const auto kernel =
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GemmTransKernelArg,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
}
else
{
throw_error();
}
} }
} }
} }
...@@ -1025,19 +901,21 @@ template <typename ALayout, ...@@ -1025,19 +901,21 @@ template <typename ALayout,
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
false, GemmTransKernelArg,
InMemoryDataOperationEnum::AtomicAdd, false,
minimum_occupancy>; InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
false, GemmTransKernelArg,
InMemoryDataOperationEnum::Set, false,
minimum_occupancy>; InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
...@@ -1115,12 +993,11 @@ template <typename ALayout, ...@@ -1115,12 +993,11 @@ template <typename ALayout,
std::vector<std::array<const void*, NumDTensor>>&, std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc> gemm_descs, std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation)
{ {
return Argument{ return Argument{p_a_grid, p_b_grid, p_c_grid, gemm_descs};
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -1132,12 +1009,11 @@ template <typename ALayout, ...@@ -1132,12 +1009,11 @@ template <typename ALayout,
std::vector<std::array<const void*, NumDTensor>>&, std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CDEElementwiseOperation cde_element_op) override CDEElementwiseOperation) override
{ {
return std::make_unique<Argument>( return std::make_unique<Argument>(p_a_grid, p_b_grid, p_c_grid, gemm_descs);
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment