Commit 7910f486 authored by Jianfeng Yan

DeviceGemmXdlSplitK and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

parent b5a9f642
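For readers skimming the change: the split-K bookkeeping that both device ops now share rounds each K slice up to a whole number of K0PerBlock * K1 tiles and lets the last slice be a shorter tail. The sketch below is a minimal host-side illustration of that arithmetic; it mirrors GetActualBatchAndKSplitted from the diff, but the function and variable names (split_k, the hard-coded K, KBatch, K0PerBlock, K1 values) are illustrative stand-ins, not library API.

#include <cstdio>
#include <utility>

// Illustrative sketch only: mirrors GetActualBatchAndKSplitted from the diff,
// using plain ints instead of ck::index_t and spelling out integer_divide_ceil.
std::pair<int, int> split_k(int K, int KBatch, int K0PerBlock, int K1)
{
    const int K0 = (K + K1 * K0PerBlock * KBatch - 1) / (K1 * K0PerBlock * KBatch) * K0PerBlock;
    const int KSplitted    = K0 * K1;                         // K covered by each full batch
    const int actual_batch = (K + KSplitted - 1) / KSplitted; // may be smaller than the requested KBatch
    return {actual_batch, KSplitted};
}

int main()
{
    const int K = 1000;                              // arbitrary K, not a multiple of the tile
    const auto [batch, KSplitted] = split_k(K, /*KBatch=*/4, /*K0PerBlock=*/8, /*K1=*/4);
    const int KTail = K - KSplitted * (batch - 1);   // remainder handled by the *_Tail descriptors
    std::printf("batch = %d, KSplitted = %d, KTail = %d, has_tail = %d\n",
                batch, KSplitted, KTail, KTail != KSplitted);
}

With K = 1000, a requested KBatch of 4, and assumed tile sizes K0PerBlock = 8 and K1 = 4, each full batch covers 256 elements of K and the fourth batch handles a 232-element tail; that remainder is why the tail descriptors and the extra TailHasMainKBlockLoop template flag exist.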
...@@ -19,6 +19,11 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
...@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
static constexpr auto K1Number = Number<K1>{};
// static constexpr index_t Getk
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
// return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
return std::make_pair(actual_batch, KSplitted);
}
...@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t K0 = KPadded / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
...@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t K0 = KPadded / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
const bool tail_has_main_k0_block_loop =
GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
const auto Run = [&](const auto& kernel)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
return launch_and_time_kernel(kernel,
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
...@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = Run(kernel);
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{
...@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
true,
false>;
ave_time = launch_and_time_kernel(kernel,
ave_time = Run(kernel);
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
...@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
false,
true>;
ave_time = launch_and_time_kernel(kernel,
ave_time = Run(kernel);
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
else
{
...@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
false,
false>;
ave_time = launch_and_time_kernel(kernel,
ave_time = Run(kernel);
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
}
else
......
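Both wrapper kernels touched by this commit (kernel_batched_gemm_xdlops_v2r3 above and kernel_batched_gemm_xdl_cshuffle_v1 below) use the same per-batch dispatch: every batch except the last runs GridwiseGemm::Run on the full-size descriptors, and the last batch switches to the *_Tail descriptors and its own main-loop flag. The sketch below serializes that decision on the host purely for illustration; Desc, run_gemm, and batched_gemm are hypothetical stand-ins, and on the GPU each batch is a group of thread blocks selected via g_idx rather than a loop.

#include <cstdio>

struct Desc
{
    int k0; // stand-in for a K0 x M x K1 grid descriptor
};

// Stand-in for GridwiseGemm::template Run<HasMainKBlockLoop>(...).
template <bool HasMainLoop>
void run_gemm(const Desc& d, int batch_idx)
{
    std::printf("batch %d: k0 = %d, has_main_k0_block_loop = %d\n", batch_idx, d.k0, HasMainLoop);
}

// The last batch uses the tail descriptor and its own main-loop flag,
// mirroring the "g_idx < batch_count - 1" branch in the wrapper kernels.
template <bool HasMainLoop, bool TailHasMainLoop>
void batched_gemm(int batch_count, const Desc& full, const Desc& tail)
{
    for(int g = 0; g < batch_count; ++g)
    {
        if(g < batch_count - 1)
            run_gemm<HasMainLoop>(full, g);
        else
            run_gemm<TailHasMainLoop>(tail, g);
    }
}

int main()
{
    batched_gemm<true, false>(4, Desc{64}, Desc{58}); // assumed example K0 sizes
}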
...@@ -19,6 +19,108 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_AK0_M_AK1_Tail,
typename BGridDesc_BK0_N_BK1_Tail,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool TailHasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail,
const BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
if(g_idx < batch_count - 1)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
else
{
GridwiseGemm::template Run<TailHasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1_tail,
b_grid_desc_bk0_n_bk1_tail,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = a_grid_desc_ak0_m_ak1_tail;
ignore = b_grid_desc_bk0_n_bk1_tail;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename CLayout,
...@@ -69,14 +171,127 @@ struct DeviceGemmXdlSplitKCShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto GetKPad(index_t K1, index_t K, index_t KBatch)
template <index_t K1>
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
const index_t K0PerBlock = KPerBlock / K1;
const index_t K0 = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
return std::make_pair(actual_batch, KSplitted);
}
template <bool IsTail>
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA);
template <bool IsTail>
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB);
/*
* No padding in K
*/
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
{
assert(K % KPerBlock == 0);
assert(K % AK1 == 0);
const auto a_grid_desc_mraw_k = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
const auto AK0 = K / AK1;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * KPerBlock * KBatch) * KPerBlock;
const index_t KPad = KBatch * K0 * K1;
assert(K % KPerBlock == 0);
assert(K % BK1 == 0);
return KPad;
const auto b_grid_desc_nraw_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
const auto BK0 = K / BK1;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<true>(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -96,14 +311,14 @@ struct DeviceGemmXdlSplitKCShuffle
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
...@@ -121,31 +336,9 @@ struct DeviceGemmXdlSplitKCShuffle
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
else
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
...@@ -161,25 +354,10 @@ struct DeviceGemmXdlSplitKCShuffle
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<true>(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -200,14 +378,13 @@ struct DeviceGemmXdlSplitKCShuffle
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{ // pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
...@@ -224,31 +401,8 @@ struct DeviceGemmXdlSplitKCShuffle
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
else // pad K, but not N
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
...@@ -264,22 +418,6 @@ struct DeviceGemmXdlSplitKCShuffle
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
...@@ -340,10 +478,11 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<false>(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<false>(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_AK0_M_AK1_Tail = decltype(MakeAGridDescriptor_AK0_M_AK1<true>(1, 1, 1));
using BGridDesc_BK0_N_BK1_Tail = decltype(MakeBGridDescriptor_BK0_N_BK1<true>(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
...@@ -418,7 +557,8 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
using Block2CTileMap =
decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
struct Argument : public BaseArgument
{
...@@ -445,21 +585,40 @@ struct DeviceGemmXdlSplitKCShuffle
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
const auto AKPad = GetKPad(AK1, KRaw, k_batch);
assert(AKPad % k_batch == 0);
const auto BKPad = GetKPad(BK1, KRaw, k_batch);
assert(BKPad % k_batch == 0);
const auto AKSplitted = AKPad / k_batch;
const auto BKSplitted = BKPad / k_batch;
const auto actual_batch_and_ksplitted_A =
GetActualBatchAndKSplitted<AK1>(KRaw, k_batch);
const auto actual_batch_and_ksplitted_B =
GetActualBatchAndKSplitted<BK1>(KRaw, k_batch);
assert(actual_batch_and_ksplitted_A.first == actual_batch_and_ksplitted_B.first);
BatchCount_ = actual_batch_and_ksplitted_A.first;
const auto AKSplitted = actual_batch_and_ksplitted_A.second;
const auto BKSplitted = actual_batch_and_ksplitted_B.second;
a_grid_desc_ak0_m_ak1_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA);
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<false>(MRaw, AKSplitted, StrideA);
b_grid_desc_bk0_n_bk1_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB);
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<false>(BKSplitted, NRaw, StrideB);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
is_valid_ = GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_);
if(KRaw != AKSplitted * BatchCount_ || KRaw != BKSplitted * BatchCount_)
{
has_tail_ = true;
const auto AKTail = KRaw - AKSplitted * (BatchCount_ - 1);
const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);
a_grid_desc_ak0_m_ak1_tail_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, AKTail, StrideA);
b_grid_desc_bk0_n_bk1_tail_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<true>(BKTail, NRaw, StrideB);
is_valid_ &= GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
}
if(is_valid_)
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -492,7 +651,8 @@ struct DeviceGemmXdlSplitKCShuffle
compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
}
}
...@@ -501,8 +661,12 @@ struct DeviceGemmXdlSplitKCShuffle
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
bool has_tail_;
bool is_valid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail_;
BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
...@@ -534,10 +698,23 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
if(arg.has_tail_)
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_tail_{"
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_tail_{"
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I2) << "}" << std::endl;
}
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!arg.is_valid_)
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
...@@ -546,127 +723,233 @@ struct DeviceGemmXdlSplitKCShuffle
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
if(has_main_k0_block_loop) if(arg.has_tail_)
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto K0_tail = arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0);
GridwiseGemm, const bool tail_has_main_k0_block_loop =
ADataType, // TODO: distiguish A/B datatype GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
CDataType,
AElementwiseOperation, const auto Run = [&](const auto& kernel) {
BElementwiseOperation, if(nrepeat == 0)
CElementwiseOperation, {
DeviceOp::AGridDesc_AK0_M_AK1, launch_kernel(kernel,
DeviceOp::BGridDesc_BK0_N_BK1, dim3(grid_size),
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, dim3(BlockSize),
ComputePtrOffsetOfStridedBatch, 0,
Block2CTileMap, arg.p_a_grid_,
true>; arg.p_b_grid_,
arg.p_c_grid_,
if(nrepeat == 0) arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true,
true>;
ave_time = Run(kernel);
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{ {
launch_kernel(kernel, const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
dim3(grid_size), GridwiseGemm,
dim3(BlockSize), ADataType, // TODO: distiguish A/B datatype
0, CDataType,
arg.p_a_grid_, AElementwiseOperation,
arg.p_b_grid_, BElementwiseOperation,
arg.p_c_grid_, CElementwiseOperation,
arg.BatchCount_, DeviceOp::AGridDesc_AK0_M_AK1,
arg.a_element_op_, DeviceOp::BGridDesc_BK0_N_BK1,
arg.b_element_op_, DeviceOp::AGridDesc_AK0_M_AK1_Tail,
arg.c_element_op_, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
arg.a_grid_desc_ak0_m_ak1_, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
arg.b_grid_desc_bk0_n_bk1_, ComputePtrOffsetOfStridedBatch,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, Block2CTileMap,
arg.compute_ptr_offset_of_batch_, true,
arg.block_2_ctile_map_); false>;
ave_time = Run(kernel);
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false,
true>;
ave_time = Run(kernel);
} }
else else
{ {
ave_time = const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
launch_and_time_kernel(kernel, GridwiseGemm,
nrepeat, ADataType, // TODO: distiguish A/B datatype
dim3(grid_size), CDataType,
dim3(BlockSize), AElementwiseOperation,
0, BElementwiseOperation,
arg.p_a_grid_, CElementwiseOperation,
arg.p_b_grid_, DeviceOp::AGridDesc_AK0_M_AK1,
arg.p_c_grid_, DeviceOp::BGridDesc_BK0_N_BK1,
arg.BatchCount_, DeviceOp::AGridDesc_AK0_M_AK1_Tail,
arg.a_element_op_, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
arg.b_element_op_, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
arg.c_element_op_, ComputePtrOffsetOfStridedBatch,
arg.a_grid_desc_ak0_m_ak1_, Block2CTileMap,
arg.b_grid_desc_bk0_n_bk1_, false,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, false>;
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_); ave_time = Run(kernel);
} }
} }
else else
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto Run = [&](const auto& kernel) {
GridwiseGemm, if(nrepeat == 0)
ADataType, // TODO: distiguish A/B datatype {
CDataType, launch_kernel(kernel,
AElementwiseOperation, dim3(grid_size),
BElementwiseOperation, dim3(BlockSize),
CElementwiseOperation, 0,
DeviceOp::AGridDesc_AK0_M_AK1, arg.p_a_grid_,
DeviceOp::BGridDesc_BK0_N_BK1, arg.p_b_grid_,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, arg.p_c_grid_,
ComputePtrOffsetOfStridedBatch, arg.BatchCount_,
Block2CTileMap, arg.a_element_op_,
false>; arg.b_element_op_,
arg.c_element_op_,
if(nrepeat == 0) arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop)
{ {
launch_kernel(kernel, const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
dim3(grid_size), GridwiseGemm,
dim3(BlockSize), ADataType, // TODO: distiguish A/B datatype
0, CDataType,
arg.p_a_grid_, AElementwiseOperation,
arg.p_b_grid_, BElementwiseOperation,
arg.p_c_grid_, CElementwiseOperation,
arg.BatchCount_, DeviceOp::AGridDesc_AK0_M_AK1,
arg.a_element_op_, DeviceOp::BGridDesc_BK0_N_BK1,
arg.b_element_op_, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
arg.c_element_op_, ComputePtrOffsetOfStridedBatch,
arg.a_grid_desc_ak0_m_ak1_, Block2CTileMap,
arg.b_grid_desc_bk0_n_bk1_, true>;
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_, ave_time = Run(kernel);
arg.block_2_ctile_map_);
} }
else else
{ {
ave_time = const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
launch_and_time_kernel(kernel, GridwiseGemm,
nrepeat, ADataType, // TODO: distiguish A/B datatype
dim3(grid_size), CDataType,
dim3(BlockSize), AElementwiseOperation,
0, BElementwiseOperation,
arg.p_a_grid_, CElementwiseOperation,
arg.p_b_grid_, DeviceOp::AGridDesc_AK0_M_AK1,
arg.p_c_grid_, DeviceOp::BGridDesc_BK0_N_BK1,
arg.BatchCount_, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
arg.a_element_op_, ComputePtrOffsetOfStridedBatch,
arg.b_element_op_, Block2CTileMap,
arg.c_element_op_, false>;
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, ave_time = Run(kernel);
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
}
return ave_time;
}
...@@ -781,7 +1064,7 @@ struct DeviceGemmXdlSplitKCShuffle
return str.str();
}
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
......
...@@ -142,6 +142,7 @@ __global__ void
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......