Unverified commit 303d4594, authored by zjing14, committed by GitHub

improved zeroing (#1221)

parent 5f2c89e8
@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
 add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
 add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
-add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
-add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
+add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
 if(USE_BITINT_EXTENSION_INT4)
 add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
...
@@ -36,7 +36,7 @@ using BDataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using DsDataType       = ck::Tuple<>;
-using EDataType        = F32;
+using EDataType        = F16;
 using ALayout = Row;
 using BLayout = Col;
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
     {
-        problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(256);
-        problem_size.Ks.push_back(128);
+        problem_size.Ms.push_back(128 + rand() % 128);
+        problem_size.Ns.push_back(1024);
+        problem_size.Ks.push_back(1024);
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);
...
@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ADataType        = F16;
 using BDataType        = F8;
 using AccDataType      = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
 using DsDataType       = ck::Tuple<>;
 using EDataType        = F16;
...
@@ -23,6 +23,7 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
+          bool Zeroing,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
@@ -106,8 +107,37 @@ __global__ void
         const auto block_2_etile_map =
             GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

+        if constexpr(Zeroing)
+        {
             auto barrier_count_finished =
                 barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
+
+            GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
+                                                  EGlobalMemoryDataOperation,
+                                                  GemmSpec,
+                                                  ALayout,
+                                                  BLayout,
+                                                  DsLayout,
+                                                  ELayout>(gemm_desc_ptr[group_id].p_a_grid,
+                                                           gemm_desc_ptr[group_id].p_b_grid,
+                                                           p_ds_grid_,
+                                                           gemm_desc_ptr[group_id].p_e_grid,
+                                                           p_shared,
+                                                           barrier_count_finished,
+                                                           a_element_op,
+                                                           b_element_op,
+                                                           c_element_op,
+                                                           M,
+                                                           N,
+                                                           K,
+                                                           StrideA,
+                                                           StrideB,
+                                                           StrideDs,
+                                                           StrideE,
+                                                           KBatch,
+                                                           block_2_etile_map);
+        }
+        else
+        {
             GridwiseGemm::template Run<HasMainKBlockLoop,
                                        EGlobalMemoryDataOperation,
@@ -120,7 +150,7 @@ __global__ void
                                        p_ds_grid_,
                                        gemm_desc_ptr[group_id].p_e_grid,
                                        p_shared,
-                                       barrier_count_finished,
+                                       nullptr,
                                        a_element_op,
                                        b_element_op,
                                        c_element_op,
@@ -133,6 +163,7 @@ __global__ void
                                        StrideE,
                                        KBatch,
                                        block_2_etile_map);
+        }

         id_off += grid_size_grp;
         id_local += grid_size_grp;
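
The new `Zeroing` template flag turns the barrier bookkeeping into a compile-time choice: the `if constexpr` branch that is not taken is never instantiated, so the `Zeroing == false` kernel carries no barrier reads, writes, or pointer arithmetic, and `nullptr` can legally be passed for the counter. A minimal stand-alone sketch of the pattern (names below are illustrative stand-ins, not the kernel's code):

```cpp
#include <cstdint>
#include <cstdio>

template <bool Zeroing>
void run_tile(std::uint32_t* barrier_count, int group_id)
{
    if constexpr(Zeroing)
    {
        // split-K path: form and use the per-group barrier slot
        std::uint32_t* slot = barrier_count + group_id;
        ++(*slot);
        std::printf("zeroing path, slot value %u\n", *slot);
    }
    else
    {
        // k_batch == 1 path: the discarded branch is never instantiated,
        // so this specialization contains no barrier bookkeeping at all
        std::printf("plain path, no barrier\n");
    }
}

int main()
{
    std::uint32_t barriers[4] = {};
    run_tile<true>(barriers, 2);  // split-K instantiation
    run_tile<false>(nullptr, 0);  // fine: the pointer is never dereferenced
}
```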
@@ -193,8 +224,11 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched     = make_default_loop_scheduler(),
           typename ComputeType = ADataType,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          typename ALDSType = ComputeType,
+          typename BLDSType = ComputeType>
 struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                         BLayout,
                                                                         DsLayout,
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

+    using AComputeType = ComputeType;
+    using BComputeType = ComputeType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
-        ComputeType,
+        AComputeType,
+        BComputeType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
+        LoopSched,
+        PipelineVer,
+        ALDSType,
+        BLDSType>;

     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops
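
Both new LDS-type parameters default to `ComputeType`, so every existing instantiation of the device op keeps compiling unchanged while mixed-width LDS staging becomes an opt-in. A small sketch of that defaulting pattern (hypothetical struct, not the real interface):

```cpp
#include <cstdio>

// ComputeType keeps its old meaning; the two LDS types default to it,
// so pre-existing users never have to name them
template <typename ComputeType,
          typename ALDSType = ComputeType,
          typename BLDSType = ComputeType>
struct DeviceOpSketch
{
    static void describe()
    {
        std::printf("sizeof(ALDSType)=%zu sizeof(BLDSType)=%zu\n",
                    sizeof(ALDSType), sizeof(BLDSType));
    }
};

int main()
{
    DeviceOpSketch<float>::describe();              // old-style use: both default
    DeviceOpSketch<float, float, char>::describe(); // opt in to mixed LDS widths
}
```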
@@ -613,10 +654,49 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
+            if(arg.k_batch_ == 1)
+            {
+                const auto kernel =
+                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
+                                                     GroupedGemmKernelArgument<NumDTensor>,
+                                                     GemmSpec,
+                                                     false,
+                                                     ALayout,
+                                                     BLayout,
+                                                     DsLayout,
+                                                     ELayout,
+                                                     DsDataType,
+                                                     Block2ETileMap,
+                                                     GroupedGemmBlock2ETileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CDEElementwiseOperation,
+                                                     e_global_memory_operation_,
+                                                     has_main_k_block_loop_>;
+
+                return launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                    nullptr,
+                    arg.barrier_size_grp_,
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.grid_size_grp_,
+                    arg.k_batch_,
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.c_element_op_);
+            }
+            else
+            {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                      GroupedGemmKernelArgument<NumDTensor>,
                                                      GemmSpec,
+                                                     true,
                                                      ALayout,
                                                      BLayout,
                                                      DsLayout,
@@ -645,13 +725,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
+            }
         };
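
The launch lambda now compiles two kernel variants and picks one at runtime: `k_batch_ == 1` gets `Zeroing = false`, while split-K launches get `Zeroing = true`. That pairing is consistent with the `Set`/`AtomicAdd` constants declared just below: a single pass can store E directly, whereas split-K partial sums accumulate atomically into an E tile that must first be zeroed. A simplified host-side model of that dispatch (stand-in types, not the real launch plumbing):

```cpp
#include <cstdio>

enum class MemOp { Set, AtomicAdd };

template <bool Zeroing>
void launch_grouped_gemm(MemOp op)
{
    std::printf("Zeroing=%d, EGlobalMemoryDataOperation=%s\n",
                int(Zeroing), op == MemOp::Set ? "Set" : "AtomicAdd");
}

void dispatch(int k_batch)
{
    if(k_batch == 1)
        launch_grouped_gemm<false>(MemOp::Set);      // one pass writes E directly
    else
        launch_grouped_gemm<true>(MemOp::AtomicAdd); // partial sums accumulate into E
}

int main()
{
    dispatch(1); // -> Zeroing=0, Set
    dispatch(4); // -> Zeroing=1, AtomicAdd
}
```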
         constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
         constexpr auto Set       = InMemoryDataOperationEnum::Set;

-        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
-        // in IsSupportedArgument function
+        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
+        // enforced in IsSupportedArgument function
         if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
         {
             if(has_main_k_block_loop)
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         bool supported = true;

-        // If we use padding we do not support vector loads for dimensions not divisible by vector
-        // load size.
+        // If we use padding we do not support vector loads for dimensions not divisible by
+        // vector load size.
         if constexpr(GemmSpec != GemmSpecialization::Default)
         {
-            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
-            // thus we have to adapt it to the {M,K} or {N,K} layout.
+            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
+            // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
             const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
             const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
...
@@ -31,7 +31,8 @@ namespace ck {
 // D0, D1, ... and E have the same layout
 template <typename ADataType,
           typename BDataType,
-          typename ComputeType,
+          typename AComputeType,
+          typename BComputeType,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,
@@ -71,7 +72,9 @@ template <typename ADataType,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer,
+          typename ALDSType,
+          typename BLDSType>
 struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(ComputeType),
+        return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
+                             b_block_space_size_aligned * sizeof(BLDSType),
                          c_block_size * sizeof(CShuffleDataType));
     }
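
The shared-memory formula changes meaningfully here: the old expression charged both the A and B tiles at `sizeof(ComputeType)`, while the new one sizes each tile by its own LDS element type, shrinking the LDS footprint whenever the two widths differ (for example fp16 A staged next to fp8 B). A worked example with made-up tile extents, using fixed-width integer types as stand-ins for the storage widths:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// sizes A and B tiles at their own element widths, like the new formula
template <typename ALds, typename BLds, typename CShuffle>
constexpr std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems, std::size_t c_elems)
{
    return std::max(a_elems * sizeof(ALds) + b_elems * sizeof(BLds),
                    c_elems * sizeof(CShuffle));
}

int main()
{
    // 4096-element A and B tiles, 2-byte A (fp16) with 1-byte B (fp8):
    // 4096*2 + 4096*1 = 12288 bytes, versus (4096+4096)*2 = 16384 bytes
    // under the old single-ComputeType formula
    std::printf("%zu\n",
                lds_bytes<std::uint16_t, std::uint8_t, std::uint16_t>(4096, 4096, 4096));
}
```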
@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
               InMemoryDataOperationEnum EGlobalMemoryDataOperation,
               index_t NumDTensor_,
               typename DsDataType_,
+              bool Zeroing,
               typename AGridDesc_KBatch_AK0_M_AK1,
               typename BGridDesc_KBatch_BK0_N_BK1,
               typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             ADataType,
-            ComputeType,
+            ALDSType,
             decltype(a_grid_desc_kbatch_ak0_m_ak1),
             decltype(a_block_desc_kbatch_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             BDataType,
-            ComputeType,
+            BLDSType,
             decltype(b_grid_desc_kbatch_bk0_n_bk1),
             decltype(b_block_desc_kbatch_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // sanity check
         constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ComputeType,
-            ComputeType,
+            ALDSType,
+            BLDSType,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -611,9 +615,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             MXdlPerWave,
             NXdlPerWave,
             KPack,
-            LoopSched>();
+            LoopSched,
+            AComputeType,
+            BComputeType>();

-#if 1
+        if constexpr(Zeroing)
+        {
             if(block_work_idx[I0] == 0)
             {
                 const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
@@ -659,14 +666,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     e_grid_desc_mblock_mperblock_nblock_nperblock,
                     e_grid_buf);

-                __syncthreads();
+                __builtin_amdgcn_s_barrier();

                 if(threadIdx.x == 0)
                 {
                     atomicAdd(barrier_count_finished, 1);
                 }
             }
-#endif
+        }
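
Under `Zeroing`, the workgroup with `block_work_idx[I0] == 0` clears the E tile, issues a workgroup barrier, and lets its thread 0 bump the per-tile counter; the matching spin-wait in a later hunk holds the split-K writers back until that signal lands. A host-side model of the zero-then-signal handshake (`std::atomic` stands in for the device atomics; the model uses acquire/release where the kernel can rely on relaxed atomics plus hardware barriers):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>

std::atomic<std::uint32_t> barrier_count_finished{0};
std::atomic<int>           e_tile[4]; // stands in for one E output tile

void zeroing_block() // models the workgroup with block_work_idx[I0] == 0
{
    for(auto& v : e_tile) v.store(0, std::memory_order_relaxed); // zero the E tile
    barrier_count_finished.fetch_add(1, std::memory_order_release); // signal "zeroed"
}

void splitk_block(int partial) // models one split-K workgroup
{
    // same spin the kernel performs with __atomic_load_n(..., __ATOMIC_RELAXED)
    while(barrier_count_finished.load(std::memory_order_acquire) == 0) {}
    for(auto& v : e_tile) v.fetch_add(partial); // like the AtomicAdd writes to E
}

int main()
{
    std::thread c1(splitk_block, 1), c2(splitk_block, 2), p(zeroing_block);
    p.join(); c1.join(); c2.join();
    std::printf("e_tile[0] = %d\n", e_tile[0].load()); // prints 3
}
```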
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
+            static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
@@ -710,13 +717,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                                       num_k_block_main_loop);

         // shuffle C and write out
         {
+            if constexpr(Zeroing)
+            {
                 if(threadIdx.x == 0)
                 {
                     while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
                 }
-            __syncthreads();
+                __builtin_amdgcn_s_barrier();
+            }

             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -951,6 +960,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 }
             });

+            if constexpr(Zeroing)
+            {
                 if(threadIdx.x == 0)
                 {
                     index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
@@ -962,6 +973,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     }
                 }
             }
+        }
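
After the write-out, each split-K workgroup's thread 0 increments the same counter; the collapsed lines around `k_id_finished_t` presumably detect the final arrival and reset the slot so it can be reused. A host-side model of one plausible lifecycle, under that assumption (the `+1` models the zeroing workgroup's earlier signal):

```cpp
#include <atomic>
#include <cstdio>

std::atomic<int> barrier_slot{0};

// one call per split-K workgroup after it has written its partial result
void arrive(int k_batch)
{
    int before = barrier_slot.fetch_add(1); // like atomicAdd(barrier_count_finished, 1)
    // the zeroing workgroup contributed one extra increment, so the last
    // arrival observes k_batch + 1 total and can recycle the slot
    if(before + 1 == k_batch + 1)
    {
        barrier_slot.store(0); // ready for the next tile
        std::printf("slot reset for reuse\n");
    }
}

int main()
{
    barrier_slot.store(1); // the zeroing workgroup's "tile is zeroed" signal
    for(int k = 0; k < 4; ++k) arrive(4);
}
```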
     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum EGlobalMemoryDataOperation,
@@ -971,7 +983,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
               typename DsLayout,
               typename ELayout,
               typename Block2ETileMap>
-    __device__ static void Run(const void* __restrict__ p_a_grid_,
+    __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
                                const void* __restrict__ p_b_grid_,
                                DsGridPointer p_ds_grid,
                                void* __restrict__ p_e_grid_,
@@ -1035,7 +1047,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         if(kbatch_id == KBatch - 1)
         {
-            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
+            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
                 p_a_grid,
                 p_b_grid,
                 p_ds_grid,
@@ -1054,7 +1066,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         }
         else
         {
-            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
+            Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
                 p_a_grid,
                 p_b_grid,
                 p_ds_grid,
@@ -1072,6 +1084,89 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 block_2_etile_map);
         }
     }
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+              typename Block2ETileMap>
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
+                               DsGridPointer p_ds_grid,
+                               void* __restrict__ p_e_grid_,
+                               void* __restrict__ p_shared,
+                               uint32_t*,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const std::array<index_t, NumDTensor> StrideDs,
+                               const index_t StrideE,
+                               const index_t KBatch,
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_kbatch_ak0_m_ak1 =
+            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
+
+        const auto b_grid_desc_kbatch_bk0_n_bk1 =
+            MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_e_grid,
+            p_shared,
+            nullptr,
+            KBatch,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_kbatch_ak0_m_ak1,
+            b_grid_desc_kbatch_bk0_n_bk1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_etile_map);
+    }
 };

 } // namespace ck