Commit 8e3c41a5 authored by Jianfeng Yan

minor changes

parent 7910f486
@@ -173,11 +173,10 @@ struct DeviceGemmXdlSplitK
         return std::make_pair(actual_batch, KSplitted);
     }

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;

         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
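Note on the change above: the renamed _Tail variant drops the divisibility assert and instead rounds K up to the next multiple of K1 * K0PerBlock. A minimal standalone sketch of that arithmetic, with hypothetical tile sizes (integer_divide_ceil(a, b) is the usual (a + b - 1) / b):

#include <cassert>

using index_t = int;

// Round K up to the next multiple of K1 * K0PerBlock, as the _Tail
// descriptor does.
index_t pad_k(index_t K, index_t K1, index_t K0PerBlock)
{
    const index_t align = K1 * K0PerBlock;
    return (K + align - 1) / align * align;
}

int main()
{
    // Hypothetical tile sizes: K1 = 8, K0PerBlock = 4 -> align = 32.
    assert(pad_k(100, 8, 4) == 128); // K = 100 pads to 128
    assert(pad_k(128, 8, 4) == 128); // already aligned: unchanged
    // K0 is then derived from the padded length: 128 / 8 == 16.
}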
@@ -190,12 +189,18 @@ struct DeviceGemmXdlSplitK
             }
         }();

+        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -204,7 +209,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
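The added a_grid_desc_m_kpad stage first right-pads the K axis to KPadded, and the existing unmerge then splits the padded axis into (K0, K1). An illustrative sketch of the index arithmetic this chain implements (plain loops standing in for transform_tensor_descriptor; values hypothetical):

#include <cassert>

using index_t = int;

int main()
{
    const index_t K = 100, K1 = 8, K0 = 16; // KPadded = K0 * K1 = 128

    // Each (k0, k1) coordinate of the K0_M_K1 view maps back to one
    // coordinate on the padded K axis; indices past the real K fall
    // in the right-pad region.
    for(index_t k0 = 0; k0 < K0; ++k0)
        for(index_t k1 = 0; k1 < K1; ++k1)
        {
            const index_t k = k0 * K1 + k1; // merged coordinate in [0, KPadded)
            assert(k < K0 * K1);
            const bool in_pad = (k >= K);   // true only for the padded tail
            (void)in_pad;
        }
}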
@@ -212,11 +217,11 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;

         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -229,12 +234,18 @@ struct DeviceGemmXdlSplitK
             }
         }();

+        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -243,7 +254,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -251,10 +262,13 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;

         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -267,18 +281,12 @@ struct DeviceGemmXdlSplitK
             }
         }();

-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -287,7 +295,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -295,11 +303,12 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;

         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -312,18 +321,12 @@ struct DeviceGemmXdlSplitK
             }
         }();

-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -332,7 +335,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -674,8 +677,7 @@ struct DeviceGemmXdlSplitK
         const bool tail_has_main_k0_block_loop =
             GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);

-        const auto Run = [&](const auto& kernel)
-        {
+        const auto Run = [&](const auto& kernel) {
             return launch_and_time_kernel(kernel,
                                           nrepeat,
                                           dim3(grid_size),
@@ -695,7 +697,6 @@ struct DeviceGemmXdlSplitK
                                           arg.c_element_op_,
                                           arg.compute_ptr_offset_of_batch_,
                                           arg.block_2_ctile_map_);
         };
-
         if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
@@ -718,7 +719,6 @@ struct DeviceGemmXdlSplitK
                         true>;

             ave_time = Run(kernel);
         }
-
         else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
         {
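Taken together, this file now keeps two descriptor families: the strict ones that assert K % (K1 * K0PerBlock) == 0, and the _Tail ones that pad. The pair (actual_batch, KSplitted) returned near the top of the diff suggests K is carved into aligned chunks whose last, possibly ragged chunk is served by the _Tail descriptors. A hedged reconstruction of such a split (the body of the function returning that pair is not part of this diff; names and rounding policy below are assumptions):

#include <cassert>
#include <utility>

using index_t = int;

// Hypothetical reconstruction: split K into k_batch chunks, each rounded up
// to a multiple of K1 * K0PerBlock so the strict descriptors' assert holds;
// the last chunk may be shorter and falls to the padded _Tail descriptors.
std::pair<index_t, index_t>
calc_batch_and_ksplitted(index_t K, index_t k_batch, index_t K1, index_t K0PerBlock)
{
    const index_t align        = K1 * K0PerBlock;
    const index_t chunk        = (K + k_batch - 1) / k_batch;     // ceil(K / k_batch)
    const index_t KSplitted    = (chunk + align - 1) / align * align;
    const index_t actual_batch = (K + KSplitted - 1) / KSplitted; // batches really needed
    return {actual_batch, KSplitted};
}

int main()
{
    // K = 1000 split 4 ways with align = 32: chunks of 256, last covers 232.
    const auto [batch, ksplit] = calc_batch_and_ksplitted(1000, 4, 8, 4);
    assert(ksplit == 256 && batch == 4);
}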
@@ -20,9 +20,10 @@ namespace tensor_operation {
 namespace device {

 /*
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ * \brief Wrapper function of GridwiseGemm::Run to realize a customized BatchedGemm for splitK.
  *
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * The main difference from \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * is that there are 2 different tensor descriptors for matrix A and B.
  */
 template <typename GridwiseGemm,
           typename FloatAB,
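The reworded \brief captures the one structural difference: because the last split-K batch can cover a shorter K range, the wrapper kernel carries two sets of A/B tensor descriptors rather than one. A minimal sketch of that dispatch idea, with entirely hypothetical names (this is not the actual kernel signature):

#include <cstdio>

using index_t = int;

// Hypothetical dispatch: every batch except the last reads through the
// "main" A/B descriptors; the last batch uses the padded tail descriptors.
template <typename DescAB, typename DescABTail, typename RunFn>
void run_split_k_batches(index_t num_batches,
                         const DescAB& desc_main,
                         const DescABTail& desc_tail,
                         RunFn&& run_one_batch)
{
    for(index_t b = 0; b < num_batches; ++b)
    {
        if(b + 1 < num_batches)
            run_one_batch(b, desc_main); // full, aligned K chunk
        else
            run_one_batch(b, desc_tail); // possibly ragged final chunk
    }
}

int main()
{
    struct Desc { const char* tag; };
    run_split_k_batches(4, Desc{"main"}, Desc{"tail"},
                        [](index_t b, const Desc& d) { std::printf("batch %d uses %s\n", b, d.tag); });
}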
@@ -193,6 +194,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
     {
+        // return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
         assert(K % KPerBlock == 0);
         assert(K % AK1 == 0);
@@ -243,6 +245,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
     {
+        // return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
         assert(K % KPerBlock == 0);
         assert(K % BK1 == 0);
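Both <false> specializations gain a commented-out fallback to their <true> counterparts, which accept arbitrary K by padding. The underlying pattern, a bool non-type template parameter selecting either a strict fast path or a padded path, sketched in isolation (simplified; a hypothetical helper, not the library's API):

#include <cassert>

using index_t = int;

template <bool PadK>
index_t make_k0(index_t K, index_t K1, index_t K0PerBlock);

// Padded path: accepts any K, rounds it up first (mirrors the <true> variants).
template <>
index_t make_k0<true>(index_t K, index_t K1, index_t K0PerBlock)
{
    const index_t align = K1 * K0PerBlock;
    return (K + align - 1) / align * align / K1;
}

// Strict path: assumes divisibility and skips the padding transform
// entirely (mirrors the <false> variants above).
template <>
index_t make_k0<false>(index_t K, index_t K1, index_t K0PerBlock)
{
    assert(K % (K1 * K0PerBlock) == 0);
    return K / K1;
}

int main()
{
    assert(make_k0<true>(100, 8, 4) == 16);  // 100 -> padded to 128 -> K0 = 16
    assert(make_k0<false>(128, 8, 4) == 16); // already aligned fast path
}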