Commit 7910f486 authored by Jianfeng Yan

DeviceGemmXdlSplitK and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

parent b5a9f642
......@@ -19,6 +19,11 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper kernel around GridwiseGemm::Run to realize BatchedGEMM.
*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
static constexpr auto K1Number = Number<K1>{};
// static constexpr index_t Getk
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
// return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
return std::make_pair(actual_batch, KSplitted);
}
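To make the split arithmetic above concrete, here is a minimal standalone sketch (not part of this diff) with illustrative tile sizes K1 = 8 and K0PerBlock = 4; it shows how KSplitted is rounded up to whole K0PerBlock * K1 tiles and how the requested KBatch can shrink for small K:
#include <cstdio>
#include <utility>
using index_t = int;
constexpr index_t K1         = 8; // illustrative
constexpr index_t K0PerBlock = 4; // illustrative
index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }
std::pair<index_t, index_t> GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
    // Round each split up to a whole number of K0PerBlock * K1 tiles.
    const index_t K0        = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const index_t KSplitted = K0 * K1;
    // The batch count actually used may be smaller than the requested KBatch.
    const index_t actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}
int main()
{
    // K = 1000, KBatch = 4: K0 = ceil(1000/128)*4 = 32, KSplitted = 256, actual_batch = 4;
    // batches 0..2 consume 256 each, the last batch covers the remaining 232 (the tail).
    const auto r1 = GetActualBatchAndKSplitted(1000, 4);
    std::printf("actual_batch = %d, KSplitted = %d\n", r1.first, r1.second);
    // K = 100, KBatch = 8: KSplitted = 32, actual_batch shrinks to 4.
    const auto r2 = GetActualBatchAndKSplitted(100, 8);
    std::printf("actual_batch = %d, KSplitted = %d\n", r2.first, r2.second);
}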
......@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPadded / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
......@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t K0 = KPadded / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
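The *_Tail descriptors above right-pad the leftover K up to the next K1 * K0PerBlock boundary so the grid-wise GEMM still sees whole tiles. A toy calculation (same illustrative K1 = 8, K0PerBlock = 4 as in the sketch above; the tail K of 232 is carried over from that example) makes the padding explicit:
#include <cstdio>
using index_t = int;
constexpr index_t K1         = 8; // illustrative
constexpr index_t K0PerBlock = 4; // illustrative
index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }
int main()
{
    const index_t K       = 232; // leftover K handled by the tail descriptors
    const index_t KPadded = integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
    const index_t K0      = KPadded / K1;
    // KPadded = ceil(232 / 32) * 32 = 256, K0 = 32; the right pad of 24 elements is
    // what make_right_pad_transform(K, KPadded - K) adds to the descriptor.
    std::printf("KPadded = %d, K0 = %d, pad = %d\n", KPadded, K0, KPadded - K);
}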
......@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
const bool tail_has_main_k0_block_loop =
GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
const auto Run = [&](const auto& kernel)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
return launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
......@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = Run(kernel);
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{
......@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
......@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
else
{
......@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
}
else
......
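Both device ops in this commit fold their four nearly identical launch blocks into a single generic Run lambda, so only the kernel instantiation differs between the (has_main_k0_block_loop, tail_has_main_k0_block_loop) branches. A plain host-side sketch of that pattern follows; fake_kernel, launch_and_time, and dispatch are invented placeholders, not CK APIs:
#include <cstdio>
template <bool HasMainKBlockLoop, bool TailHasMainKBlockLoop>
void fake_kernel(int grid_size)
{
    std::printf("kernel<%d,%d> launched on %d blocks\n",
                int(HasMainKBlockLoop), int(TailHasMainKBlockLoop), grid_size);
}
template <typename Kernel>
float launch_and_time(Kernel kernel, int grid_size)
{
    kernel(grid_size); // a real implementation would launch on the GPU and time it
    return 0.0f;
}
float dispatch(bool has_main, bool tail_has_main, int grid_size)
{
    // The shared launch/timing code is written once and reused by every branch.
    const auto Run = [&](auto kernel) { return launch_and_time(kernel, grid_size); };
    if(has_main && tail_has_main)
        return Run(fake_kernel<true, true>);
    else if(has_main && !tail_has_main)
        return Run(fake_kernel<true, false>);
    else if(!has_main && tail_has_main)
        return Run(fake_kernel<false, true>);
    else
        return Run(fake_kernel<false, false>);
}
int main() { dispatch(true, false, 64); }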
......@@ -19,6 +19,108 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper kernel around GridwiseGemm::Run to realize BatchedGEMM.
*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_AK0_M_AK1_Tail,
typename BGridDesc_BK0_N_BK1_Tail,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool TailHasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail,
const BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
if(g_idx < batch_count - 1)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
else
{
GridwiseGemm::template Run<TailHasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1_tail,
b_grid_desc_bk0_n_bk1_tail,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = a_grid_desc_ak0_m_ak1_tail;
ignore = b_grid_desc_bk0_n_bk1_tail;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
}
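// Standalone sketch (not part of this header) of the block-to-batch mapping used by
// the kernel above: the 1-D block id is folded into a batch index g_idx, and only the
// last batch (g_idx == batch_count - 1) runs through the *_tail descriptors.
// grid_size and batch_count below are illustrative values.
#include <cstdio>
int main()
{
    const int batch_count = 4;  // e.g. the actual split-K batch count
    const int grid_size   = 32; // CalculateGridSize(c_grid_desc_m_n) * batch_count
    const int num_blocks_per_batch = grid_size / batch_count; // 8
    for(int block_id = 0; block_id < grid_size; block_id += num_blocks_per_batch)
    {
        const int g_idx      = block_id / num_blocks_per_batch;
        const bool uses_tail = (g_idx == batch_count - 1);
        std::printf("batch %d uses %s descriptors\n", g_idx, uses_tail ? "tail" : "main");
    }
}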
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -69,14 +171,127 @@ struct DeviceGemmXdlSplitKCShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto GetKPad(index_t K1, index_t K, index_t KBatch)
template <index_t K1>
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
const index_t K0PerBlock = KPerBlock / K1;
const index_t K0 = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
return std::make_pair(actual_batch, KSplitted);
}
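// Standalone sketch (not part of this header) of the templated split above, evaluated
// once with AK1 and once with BK1 the way the Argument constructor below does.
// KPerBlock = 32, AK1 = 8, BK1 = 4 are illustrative values; since KPerBlock is divisible
// by both K1 values, KSplitted is a multiple of KPerBlock and A and B agree on the split.
#include <cassert>
#include <cstdio>
#include <utility>
using index_t = int;
constexpr index_t KPerBlock = 32;
index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }
template <index_t K1>
std::pair<index_t, index_t> GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
    const index_t K0PerBlock   = KPerBlock / K1;
    const index_t K0           = integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
    const index_t KSplitted    = K0 * K1; // = ceil(K / (KPerBlock * KBatch)) * KPerBlock
    const index_t actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}
int main()
{
    const index_t KRaw = 1000, KBatch = 4;
    const auto a = GetActualBatchAndKSplitted<8>(KRaw, KBatch); // A side, AK1 = 8
    const auto b = GetActualBatchAndKSplitted<4>(KRaw, KBatch); // B side, BK1 = 4
    assert(a.first == b.first && a.second == b.second);
    // K left for the last batch, handled by the *_Tail descriptors:
    const index_t KTail = KRaw - a.second * (a.first - 1); // 1000 - 256 * 3 = 232
    std::printf("batch = %d, KSplitted = %d, KTail = %d\n", a.first, a.second, KTail);
}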
template <bool IsTail>
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA);
template <bool IsTail>
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB);
/*
* No padding in K
*/
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
{
assert(K % KPerBlock == 0);
assert(K % AK1 == 0);
const auto a_grid_desc_mraw_k = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
const auto AK0 = K / AK1;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * KPerBlock * KBatch) * KPerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
assert(K % KPerBlock == 0);
assert(K % BK1 == 0);
const auto b_grid_desc_nraw_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
const auto BK0 = K / BK1;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<true>(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -96,14 +311,14 @@ struct DeviceGemmXdlSplitKCShuffle
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
......@@ -121,31 +336,9 @@ struct DeviceGemmXdlSplitKCShuffle
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
else
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
......@@ -161,25 +354,10 @@ struct DeviceGemmXdlSplitKCShuffle
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<true>(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -200,14 +378,13 @@ struct DeviceGemmXdlSplitKCShuffle
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{ // pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
......@@ -224,31 +401,8 @@ struct DeviceGemmXdlSplitKCShuffle
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
else // pad K, but not N
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
......@@ -264,22 +418,6 @@ struct DeviceGemmXdlSplitKCShuffle
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
......@@ -340,10 +478,11 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<false>(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<false>(1, 1, 1));
using AGridDesc_AK0_M_AK1_Tail = decltype(MakeAGridDescriptor_AK0_M_AK1<true>(1, 1, 1));
using BGridDesc_BK0_N_BK1_Tail = decltype(MakeBGridDescriptor_BK0_N_BK1<true>(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
......@@ -418,7 +557,8 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
using Block2CTileMap =
decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
struct Argument : public BaseArgument
{
......@@ -445,21 +585,40 @@ struct DeviceGemmXdlSplitKCShuffle
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
const auto AKPad = GetKPad(AK1, KRaw, k_batch);
assert(AKPad % k_batch == 0);
const auto BKPad = GetKPad(BK1, KRaw, k_batch);
assert(BKPad % k_batch == 0);
const auto AKSplitted = AKPad / k_batch;
const auto BKSplitted = BKPad / k_batch;
const auto actual_batch_and_ksplitted_A =
GetActualBatchAndKSplitted<AK1>(KRaw, k_batch);
const auto actual_batch_and_ksplitted_B =
GetActualBatchAndKSplitted<BK1>(KRaw, k_batch);
assert(actual_batch_and_ksplitted_A.first == actual_batch_and_ksplitted_B.first);
BatchCount_ = actual_batch_and_ksplitted_A.first;
const auto AKSplitted = actual_batch_and_ksplitted_A.second;
const auto BKSplitted = actual_batch_and_ksplitted_B.second;
a_grid_desc_ak0_m_ak1_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA);
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<false>(MRaw, AKSplitted, StrideA);
b_grid_desc_bk0_n_bk1_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB);
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<false>(BKSplitted, NRaw, StrideB);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
is_valid_ = GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_);
if(KRaw != AKSplitted * BatchCount_ || KRaw != BKSplitted * BatchCount_)
{
has_tail_ = true;
const auto AKTail = KRaw - AKSplitted * (BatchCount_ - 1);
const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);
a_grid_desc_ak0_m_ak1_tail_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, AKTail, StrideA);
b_grid_desc_bk0_n_bk1_tail_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<true>(BKTail, NRaw, StrideB);
is_valid_ &= GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
}
if(is_valid_)
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -492,7 +651,8 @@ struct DeviceGemmXdlSplitKCShuffle
compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
}
}
......@@ -501,8 +661,12 @@ struct DeviceGemmXdlSplitKCShuffle
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
bool has_tail_;
bool is_valid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail_;
BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
......@@ -534,10 +698,23 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
if(arg.has_tail_)
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_tail_{"
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_tail_{"
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I2) << "}" << std::endl;
}
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!arg.is_valid_)
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
......@@ -546,127 +723,233 @@ struct DeviceGemmXdlSplitKCShuffle
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
float ave_time = 0;
if(has_main_k0_block_loop)
if(arg.has_tail_)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true>;
if(nrepeat == 0)
const auto K0_tail = arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0);
const bool tail_has_main_k0_block_loop =
GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
const auto Run = [&](const auto& kernel) {
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true,
true>;
ave_time = Run(kernel);
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true,
false>;
ave_time = Run(kernel);
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false,
true>;
ave_time = Run(kernel);
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false,
false>;
ave_time = Run(kernel);
}
}
else
{
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false>;
if(nrepeat == 0)
const auto Run = [&](const auto& kernel) {
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true>;
ave_time = Run(kernel);
}
else
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false>;
ave_time = Run(kernel);
}
}
return ave_time;
}
......@@ -781,7 +1064,7 @@ struct DeviceGemmXdlSplitKCShuffle
return str.str();
}
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
......
......@@ -142,6 +142,7 @@ __global__ void
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......