Commit 8e3c41a5 authored by Jianfeng Yan

minor changes

parent 7910f486
@@ -173,11 +173,10 @@ struct DeviceGemmXdlSplitK
         return std::make_pair(actual_batch, KSplitted);
     }
 
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;
 
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -190,12 +189,18 @@ struct DeviceGemmXdlSplitK
             }
         }();
 
+        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -204,7 +209,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -212,11 +217,11 @@ struct DeviceGemmXdlSplitK
         }
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;
 
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -229,12 +234,18 @@ struct DeviceGemmXdlSplitK
             }
         }();
 
+        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -243,7 +254,7 @@ struct DeviceGemmXdlSplitK
         else
        {
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -251,10 +262,13 @@ struct DeviceGemmXdlSplitK
         }
     }
 
-    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;
 
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -267,18 +281,12 @@ struct DeviceGemmXdlSplitK
             }
         }();
 
-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -287,7 +295,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -295,11 +303,12 @@ struct DeviceGemmXdlSplitK
         }
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;
 
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -312,18 +321,12 @@ struct DeviceGemmXdlSplitK
             }
         }();
 
-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -332,7 +335,7 @@ struct DeviceGemmXdlSplitK
         else
        {
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -674,28 +677,26 @@ struct DeviceGemmXdlSplitK
             const bool tail_has_main_k0_block_loop =
                 GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
 
-            const auto Run = [&](const auto& kernel)
-            {
+            const auto Run = [&](const auto& kernel) {
                 return launch_and_time_kernel(kernel,
-                    nrepeat,
-                    dim3(grid_size),
-                    dim3(BlockSize),
-                    0,
-                    arg.p_a_grid_,
-                    arg.p_b_grid_,
-                    arg.p_c_grid_,
-                    arg.BatchCount_,
-                    arg.a_grid_desc_k0_m_k1_,
-                    arg.b_grid_desc_k0_n_k1_,
-                    arg.a_grid_desc_k0_m_k1_tail_,
-                    arg.b_grid_desc_k0_n_k1_tail_,
-                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                    arg.a_element_op_,
-                    arg.b_element_op_,
-                    arg.c_element_op_,
-                    arg.compute_ptr_offset_of_batch_,
-                    arg.block_2_ctile_map_);
+                                              nrepeat,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.p_a_grid_,
+                                              arg.p_b_grid_,
+                                              arg.p_c_grid_,
+                                              arg.BatchCount_,
+                                              arg.a_grid_desc_k0_m_k1_,
+                                              arg.b_grid_desc_k0_n_k1_,
+                                              arg.a_grid_desc_k0_m_k1_tail_,
+                                              arg.b_grid_desc_k0_n_k1_tail_,
+                                              arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.c_element_op_,
+                                              arg.compute_ptr_offset_of_batch_,
+                                              arg.block_2_ctile_map_);
             };
 
             if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
@@ -718,7 +719,6 @@ struct DeviceGemmXdlSplitK
                     true>;
                 ave_time = Run(kernel);
             }
-
             else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
             {
......
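Note: the hunks above swap the two descriptor builders so that the K-padded variants become MakeAGridDescriptor_K0_M_K1_Tail / MakeBGridDescriptor_K0_N_K1_Tail, while the plain names keep asserting that K divides K1 * K0PerBlock exactly; the launcher then passes both the regular and the tail descriptors to the kernel and dispatches on whether each K0 loop has a main block loop. Below is a minimal standalone sketch of the rounding arithmetic only, with hypothetical values for the K1 and K0PerBlock template parameters (not taken from this commit):

#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Same rounding as math::integer_divide_ceil in the library.
index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    constexpr index_t K1         = 8; // hypothetical tile parameters
    constexpr index_t K0PerBlock = 4;

    const index_t K = 100; // does not divide K1 * K0PerBlock = 32

    // KPadded rounds K up to the next multiple of K1 * K0PerBlock,
    // mirroring MakeAGridDescriptor_K0_M_K1_Tail above.
    const index_t KPadded = integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
    const index_t K0      = KPadded / K1;

    assert(KPadded == 128 && K0 == 16);
    // The right-pad transform then covers the KPadded - K = 28 trailing
    // elements, so the K0 loop never reads a partial block.
    return 0;
}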
@@ -20,9 +20,10 @@ namespace tensor_operation {
 namespace device {
 
 /*
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ * \brief Wrapper function of GridwiseGemm::Run to realize a customized BatchedGemm for splitK.
  *
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * The main difference from \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * is that there are 2 different tensor descriptors for matrix A and B.
 */
 template <typename GridwiseGemm,
           typename FloatAB,
@@ -174,7 +175,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <index_t K1>
     static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
-        const index_t K0PerBlock = KPerBlock / K1;
+        const index_t K0PerBlock = KPerBlock / K1;
         const index_t K0 = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
         const index_t KSplitted = K0 * K1;
         const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
@@ -193,6 +194,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
     {
+        // return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
         assert(K % KPerBlock == 0);
         assert(K % AK1 == 0);
@@ -243,6 +245,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
     {
+        // return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
         assert(K % KPerBlock == 0);
         assert(K % BK1 == 0);
......
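Note: GetActualBatchAndKSplitted (touched here only by whitespace) is what produces the tail in the first place: each K split is rounded up to a multiple of KPerBlock, so the requested KBatch can collapse to a smaller actual_batch whose last split covers fewer than KSplitted elements; the <false> descriptor specializations assert exact divisibility, and the commented-out calls point at the padded <true> variants for the tail. A self-contained sketch with hypothetical KPerBlock and K1 values (the real ones are template parameters of DeviceGemmXdlSplitKCShuffle):

#include <cassert>
#include <cstdint>
#include <utility>

using index_t = int32_t;

index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

// Mirrors GetActualBatchAndKSplitted with made-up tile sizes.
std::pair<index_t, index_t> get_actual_batch_and_k_splitted(index_t K, index_t KBatch)
{
    constexpr index_t KPerBlock  = 32; // hypothetical
    constexpr index_t K1         = 8;  // hypothetical
    constexpr index_t K0PerBlock = KPerBlock / K1;

    const index_t K0           = integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
    const index_t KSplitted    = K0 * K1;
    const index_t actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}

int main()
{
    // K = 100 requested in 8 splits: KSplitted = ceil(100 / 256) * 32 = 32,
    // actual_batch = ceil(100 / 32) = 4 -- fewer batches than requested, with
    // three full 32-element splits and a 4-element tail (100 - 3 * 32) that
    // the padded tail descriptors must cover.
    const auto [actual_batch, k_splitted] = get_actual_batch_and_k_splitted(100, 8);
    assert(actual_batch == 4 && k_splitted == 32);
    return 0;
}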