Commit 7910f486 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

DeviceGemmXdlSplit and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

parent b5a9f642
......@@ -19,6 +19,11 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
static constexpr auto K1Number = Number<K1>{};
// static constexpr index_t Getk
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
// return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
return std::make_pair(actual_batch, KSplitted);
}
......@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPadded / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
......@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
{
const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPad / K1;
const index_t K0 = KPadded / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
const bool tail_has_main_k0_block_loop =
GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
const auto Run = [&](const auto& kernel)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
return launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
......@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
const auto kernel = kernel_batched_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true,
true>;
ave_time = Run(kernel);
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{
......@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
......@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
else
{
......@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
ave_time = Run(kernel);
}
}
else
......
......@@ -142,6 +142,7 @@ __global__ void
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment