"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "06c5a07255f7ad37032b17f17949065f14f84f8a"
Commit bf73d297 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fixes

parent a5e9069f
...@@ -28,18 +28,14 @@ template <typename GridwiseGemm, ...@@ -28,18 +28,14 @@ template <typename GridwiseGemm,
typename GemmDesc, typename GemmDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, index_t MinimumOccupancy = 1,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, TailNumber TailNum = TailNumber::Full>
typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count)
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__)) defined(__gfx94__))
...@@ -68,13 +64,19 @@ __global__ void ...@@ -68,13 +64,19 @@ __global__ void
group_id = index_t((left + right) / 2); group_id = index_t((left + right) / 2);
} }
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( const auto karg = gemm_desc_ptr[group_id].karg_;
gemm_desc_ptr[group_id].karg_, auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_, GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
a_element_op, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
b_element_op, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
c_element_op); karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op gemm_desc_ptr[group_id].block_2_ctile_map_);
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -131,19 +133,24 @@ template <typename ALayout, ...@@ -131,19 +133,24 @@ template <typename ALayout,
typename ComputeTypeA = EDataType, typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false, bool PermuteA = false,
bool PermuteB = false> bool PermuteB = false,
// MultipleD not supported for now.
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout, enable_if_t<is_same_v<DsLayout, ck::Tuple<>> && is_same_v<DsDataType, ck::Tuple<>>,
BLayout, bool> = false>
DsLayout, >
ELayout,
ADataType, struct DeviceGroupedGemmXdlSplitKCShuffle
BDataType, : public DeviceGroupedGemmSplitK<ALayout,
DsDataType, BLayout,
EDataType, DsLayout,
AElementwiseOperation, ELayout,
BElementwiseOperation, ADataType,
CDEElementwiseOperation> BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -198,7 +205,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -198,7 +205,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BlkGemmPipeSched, BlkGemmPipeSched,
BlkGemmPipelineVer, BlkGemmPipelineVer,
ComputeTypeA, ComputeTypeA,
ComputeTypeB>; ComputeTypeB,
PermuteA,
PermuteB>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
...@@ -209,16 +218,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -209,16 +218,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
KernelArgument karg_; KernelArgument karg_;
// GroupedGemmBlock2ETileMap block_2_ctile_map_; GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
GemmTransKernelArg() = default; GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg, GemmTransKernelArg(KernelArgument&& karg,
// GroupedGemmBlock2ETileMap&& b2c_map, GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start, index_t block_start,
index_t block_end) index_t block_end)
: karg_{karg}, : karg_{karg},
// block_2_ctile_map_{b2c_map}, block_2_ctile_map_{b2c_map},
block_start_{block_start}, block_start_{block_start},
block_end_{block_end} block_end_{block_end}
{ {
...@@ -234,8 +243,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -234,8 +243,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Argument(std::vector<const void*>& p_a_grid, Argument(std::vector<const void*>& p_a_grid,
std::vector<const void*>& p_b_grid, std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs) std::vector<GemmDesc>& gemm_descs,
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch) AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch, a_element_op, b_element_op, cde_element_op )
{ {
// TODO: use occupancy api to calculate appropriate batch size. // TODO: use occupancy api to calculate appropriate batch size.
} }
...@@ -244,7 +256,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -244,7 +256,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<const void*>& p_b_grid, std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
index_t kbatch) index_t kbatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
: K_BATCH{kbatch} : K_BATCH{kbatch}
{ {
grid_size_ = 0; grid_size_ = 0;
...@@ -267,7 +282,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -267,7 +282,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M == 0) if(M * N * K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
...@@ -277,12 +292,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -277,12 +292,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
index_t gdx, gdy, gdz; const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH); index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
grid_size_grp *= K_BATCH;
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
// const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp; const index_t block_end = grid_size_ + grid_size_grp;
...@@ -290,24 +302,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -290,24 +302,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// block-to-e-tile map // block-to-e-tile map
// auto grouped_block_2_ctile_map = auto grouped_block_2_ctile_map =
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]), KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
type_convert<const BDataType*>(p_b_grid[i]), type_convert<const BDataType*>(p_b_grid[i]),
{}, // p_ds_grid
type_convert<EDataType*>(p_c_grid[i]), type_convert<EDataType*>(p_c_grid[i]),
M, M,
N, N,
K, K,
stride_a, stride_a,
stride_b, stride_b,
{}, // StrideDs_
stride_c, stride_c,
K_BATCH}; K_BATCH,
a_element_op,
b_element_op,
cde_element_op};
// gemm_kernel_args_.emplace_back( gemm_kernel_args_.emplace_back(
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
} }
} }
...@@ -326,28 +341,22 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -326,28 +341,22 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
// const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M); const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
// const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N); index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
grid_size_grp *= K_BATCH;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
// const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp; const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// auto grouped_block_2_ctile_map = // block-to-e-tile map
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KBatch = K_BATCH; karg.KBatch = K_BATCH;
// gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end; gemm_kernel_args_[i].block_end_ = block_end;
} }
} }
...@@ -365,45 +374,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -365,45 +374,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
bool all_have_main_k_block_loop{true}; const auto& karg0 = arg.gemm_kernel_args_[0].karg_;
bool all_have_kbatch_gt_one; index_t k_grain0 = karg0.KBatch * KPerBlock;
index_t K_split0 = (karg0.K + k_grain0 - 1) / k_grain0 * KPerBlock;
bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0);
const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0);
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
const auto& karg = arg.gemm_kernel_args_[i].karg_;
all_have_kbatch_gt_one = karg.KBatch > 1;
index_t k_grain = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
karg.Print(); karg.Print();
} }
auto kbatch = karg.KBatch; const auto& karg = arg.gemm_kernel_args_[i].karg_;
if(!GridwiseGemm::CheckValidity(karg)) index_t k_grain = karg.KBatch * KPerBlock;
index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock;
bool not_all_have_main_k0_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
bool not_all_have_tail_num_same =
(tail_num == GridwiseGemm::CalculateKBlockLoopTailNum(K_split));
if(not_all_have_main_k0_block_loop_same)
{ {
std::ostringstream err; std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__; << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); if(not_all_have_tail_num_same)
{
std::ostringstream err;
err << "Not all gemms have same TailNumber value! in " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same) if(!GridwiseGemm::CheckValidity(karg))
{ {
std::ostringstream err; std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! " err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
<< "group [" << i << "], kbatch: " << kbatch << ":" << __LINE__ << ", in function: " << __func__;
<< ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.KBatch
<< " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
...@@ -418,64 +435,71 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -418,64 +435,71 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
float ave_time = 0; float ave_time = 0;
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(all_have_kbatch_gt_one) if(stream_config.flush_cache)
{ {
for(const auto& trans_arg : arg.gemm_kernel_args_) const auto& arg_ = arg.gemm_kernel_args_[0].karg_;
{
const auto& karg = trans_arg.karg_; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
hip_check_error(hipMemsetAsync(karg.p_c_grid, arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
0, const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
karg.M * karg.N * sizeof(EDataType), arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
stream_config.stream_id_));
} auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
// TODO: should be loop here through all groups
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
} }
else
for(const auto& trans_arg : arg.gemm_kernel_args_)
{ {
const auto& karg = trans_arg.karg_; // TODO: should be loop here through all groups
ave_time += launch_and_time_kernel( if(arg.KBatch > 1)
stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0, karg); hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
} }
}; };
constexpr index_t minimum_occupancy = constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// Calculate TailNumber for one
auto calculate_tail_number = [&]() {
index_t k_grain = arg.gemm_kernel_args_[0].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[0].karg_.K + k_grain - 1) / k_grain * KPerBlock;
return GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
};
auto all_have_same_tail_number = [&]() {
// Calculate TailNumber for one
auto tail_number = calculate_tail_number();
// Calculate TailNumber for every other arg and compare
for(size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
{
index_t k_grain = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
if(tail_number != GridwiseGemm::CalculateKBlockLoopTailNum(K_split))
{
return false;
}
}
return true;
};
auto throw_error = [&]() {
std::ostringstream err;
err << "Not all gemms have same TailNumber value! ";
throw std::runtime_error(err.str());
};
if(all_have_main_k_block_loop) if(all_have_main_k_block_loop)
{ {
// Tail number always full // Tail number always full
...@@ -485,19 +509,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -485,19 +509,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
true, GemmTransKernelArg,
InMemoryDataOperationEnum::AtomicAdd, true,
minimum_occupancy>; InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
true, GemmTransKernelArg,
InMemoryDataOperationEnum::Set, true,
minimum_occupancy>; InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
...@@ -507,24 +533,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -507,24 +533,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{ {
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
if(calculate_tail_number() == TailNumber::One) if(tail_num == TailNumber::One)
{ {
if(all_have_same_tail_number()) const auto kernel =
{ kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< GemmTransKernelArg true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::One>;
minimum_occupancy, Run(kernel);
TailNumber::One>;
Run(kernel);
}
else
{
throw_error();
}
} }
//// TODO: Fix below as above!
else if(calculate_tail_number() == TailNumber::Full) else if(calculate_tail_number() == TailNumber::Full)
{ {
if(all_have_same_tail_number()) if(all_have_same_tail_number())
...@@ -1094,11 +1115,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -1094,11 +1115,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<std::array<const void*, NumDTensor>>&, std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc> gemm_descs, std::vector<GemmDesc> gemm_descs,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CDEElementwiseOperation) CDEElementwiseOperation cde_element_op)
{ {
return Argument{p_a_grid, p_b_grid, p_c_grid, gemm_descs}; return Argument{
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -1110,11 +1132,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -1110,11 +1132,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<std::array<const void*, NumDTensor>>&, std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid, std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CDEElementwiseOperation) override CDEElementwiseOperation cde_element_op) override
{ {
return std::make_unique<Argument>(p_a_grid, p_b_grid, p_c_grid, gemm_descs); return std::make_unique<Argument>(
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment