Commit 7910f486 authored by Jianfeng Yan

DeviceGemmXdlSplitK and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

parent b5a9f642
@@ -19,6 +19,11 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
     static constexpr auto K1Number = Number<K1>{};

-    // static constexpr index_t Getk
     static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
         const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
         const index_t KSplitted = K0 * K1;
         const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
-        // return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
         return std::make_pair(actual_batch, KSplitted);
     }
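To make the rounding above concrete, here is a small standalone sketch of the same arithmetic (plain C++; the tile constants K1 = 4 and K0PerBlock = 8 are illustrative assumptions, not values taken from this commit). Each K split is rounded up to a multiple of K1 * K0PerBlock, so the requested KBatch can shrink to a smaller actual_batch that still covers all of K:

#include <cstdio>
#include <utility>

constexpr int K1         = 4; // assumed for illustration
constexpr int K0PerBlock = 8; // assumed for illustration

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// Mirrors the rounding in GetActualBatchAndKSplitted above.
std::pair<int, int> get_actual_batch_and_k_splitted(int K, int KBatch)
{
    const int K0           = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const int KSplitted    = K0 * K1; // per-split K, a multiple of K1 * K0PerBlock = 32
    const int actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}

int main()
{
    // K = 1000 with a requested 4-way split: each split rounds up to 256,
    // and ceil(1000 / 256) = 4 splits cover all of K.
    const auto [batch, k_splitted] = get_actual_batch_and_k_splitted(1000, 4);
    std::printf("actual_batch=%d KSplitted=%d\n", batch, k_splitted); // prints 4 and 256
}

In this example the last split owns only 1000 - 3 * 256 = 232 elements of K, which is why the Tail descriptors in the next hunks right-pad K up to a whole number of K0PerBlock chunks.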
@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
     static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)

@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
         const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
             a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
     static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)

@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
         const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
             b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
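The KPad to KPadded rename is cosmetic, but the padding arithmetic it names is the heart of the arbitrary-K support, so here is the same computation in isolation (plain C++, reusing the illustrative K1 = 4 and K0PerBlock = 8 assumed in the sketch above):

#include <cstdio>

constexpr int K1         = 4; // assumed for illustration
constexpr int K0PerBlock = 8; // assumed for illustration

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int K = 232; // e.g. the tail split's leftover K from the earlier example

    // Same rounding as MakeAGridDescriptor_K0_M_K1_Tail / MakeBGridDescriptor_K0_N_K1_Tail:
    // round K up to a multiple of K1 * K0PerBlock, then derive K0 and the pad amount
    // that is handed to make_right_pad_transform.
    const int KPadded = integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
    const int K0      = KPadded / K1;
    const int pad     = KPadded - K;

    std::printf("KPadded=%d K0=%d pad=%d\n", KPadded, K0, pad); // prints 256, 64 and 24
}

After the pad, every block iterates over a whole number of K0PerBlock chunks regardless of the original K.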
@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
             const bool tail_has_main_k0_block_loop =
                 GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);

-            if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
-            {
-                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distinguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
-                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    true,
-                    true>;
-
-                ave_time = launch_and_time_kernel(kernel,
+            const auto Run = [&](const auto& kernel)
+            {
+                return launch_and_time_kernel(kernel,
                                               nrepeat,
                                               dim3(grid_size),
                                               dim3(BlockSize),

@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
                                               arg.c_element_op_,
                                               arg.compute_ptr_offset_of_batch_,
                                               arg.block_2_ctile_map_);
+            };
+
+            if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+            {
+                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
+                    ComputePtrOffsetOfStridedBatch,
+                    remove_reference_t<Block2CTileMap>,
+                    true,
+                    true>;
+
+                ave_time = Run(kernel);
             }
             else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
             {

@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
                     true,
                     false>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
             else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
             {

@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
                     false,
                     true>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
             else
             {

@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
                     false,
                     false>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
         }
         else
...
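Taken together, the hunks above are a straightforward deduplication: four launch_and_time_kernel call sites that differed only in which kernel they launched collapse into one generic Run lambda. A minimal standalone sketch of the pattern (plain C++; all names here are stand-ins, not this library's API):

#include <cstdio>

// Stand-in for a kernel template selected by two compile-time flags,
// mirroring the trailing <bool, bool> arguments of kernel_batched_gemm_xdlops_v2r3.
template <bool MainLoop, bool TailMainLoop>
void kernel_stub()
{
    std::printf("launched kernel<%d, %d>\n", MainLoop, TailMainLoop);
}

// Stand-in for launch_and_time_kernel: runs the kernel nrepeat times and
// would return the measured average time.
template <typename Kernel>
float launch_and_time_stub(Kernel kernel, int nrepeat)
{
    for(int i = 0; i < nrepeat; ++i)
        kernel();
    return 0.0f;
}

int main()
{
    const int nrepeat             = 2;
    const bool has_main_loop      = true;  // example dispatch inputs
    const bool tail_has_main_loop = false;

    // The shared launch arguments are captured once; because the lambda's
    // parameter is auto, each branch can pass a differently-typed kernel
    // through the same call site, exactly like ave_time = Run(kernel) above.
    const auto Run = [&](const auto& kernel) { return launch_and_time_stub(kernel, nrepeat); };

    float ave_time = 0.0f;
    if(has_main_loop && tail_has_main_loop)
        ave_time = Run(kernel_stub<true, true>);
    else if(has_main_loop && !tail_has_main_loop)
        ave_time = Run(kernel_stub<true, false>);
    else if(!has_main_loop && tail_has_main_loop)
        ave_time = Run(kernel_stub<false, true>);
    else
        ave_time = Run(kernel_stub<false, false>);
    (void)ave_time;
}

A generic lambda is the natural tool here: a plain helper function could not accept four different kernel instantiation types through one parameter without extra template machinery.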
@@ -142,6 +142,7 @@ __global__ void
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
+
 template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
...