Commit 06f57782 authored by Jianfeng Yan

regress to using 1 grid_desc

parent 308146e7
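
This change drops the separate `*_Tail` grid-descriptor types from the split-K device ops: the non-padded `MakeAGridDescriptor_K0_M_K1` / `MakeBGridDescriptor_K0_N_K1` makers are deleted and the K-padding `*_Tail` makers take over their names, so the regular K-splits and the tail split share a single descriptor type (the "1 grid_desc" of the commit message). The standalone sketch below illustrates why one padded maker covers both cases; it is not the library code. The tile sizes and the split value are hypothetical stand-ins for the `K1` / `K0PerBlock` template parameters and the computed `KSplitted`; only `KTail = K - KSplitted * (BatchCount - 1)` is taken from the diff.

```cpp
#include <cstdio>
#include <utility>

// Hypothetical tile sizes standing in for the K1 / K0PerBlock template
// parameters in this diff; the real values come from the device-op config.
constexpr int K1         = 8;
constexpr int K0PerBlock = 4;

// Mirrors the surviving MakeAGridDescriptor_K0_M_K1: K is always padded up
// to a multiple of K1 * K0PerBlock before being unmerged into (K0, K1).
std::pair<int, int> make_k0_k1(int K)
{
    const int KPadded = (K + K1 * K0PerBlock - 1) / (K1 * K0PerBlock) * (K1 * K0PerBlock);
    return {KPadded / K1, K1};
}

int main()
{
    const int K = 1000, BatchCount = 3;
    const int KSplitted = 320;                              // illustrative per-split K
    const int KTail     = K - KSplitted * (BatchCount - 1); // 360, computed as in the diff

    const auto regular = make_k0_k1(KSplitted); // already a multiple of 32 -> (40, 8)
    const auto tail    = make_k0_k1(KTail);     // padded from 360 to 384   -> (48, 8)

    // Same maker, same descriptor shape family: one grid_desc type suffices
    // for both the regular splits and the tail split.
    std::printf("regular: K0=%d K1=%d\ntail:    K0=%d K1=%d\n",
                regular.first, regular.second, tail.first, tail.second);
}
```
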
@@ -29,8 +29,6 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename AGridDesc_K0_M_K1_Tail,
typename BGridDesc_K0_N_K1_Tail,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -50,8 +48,8 @@ __global__ void
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail,
const BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_tail,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
@@ -173,7 +171,7 @@ struct DeviceGemmXdlSplitK
return std::make_pair(actual_batch, KSplitted);
}
static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
const index_t K0 = KPadded / K1;
@@ -217,7 +215,7 @@ struct DeviceGemmXdlSplitK
}
}
static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
@@ -262,87 +260,6 @@ struct DeviceGemmXdlSplitK
}
}
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
// return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
assert(K % (K1 * K0PerBlock) == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
// return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
assert(K % (K1 * K0PerBlock) == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
@@ -378,11 +295,9 @@ struct DeviceGemmXdlSplitK
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using AGridDesc_K0_M_K1_Tail = decltype(MakeAGridDescriptor_K0_M_K1_Tail(1, 1, 1));
using BGridDesc_K0_N_K1_Tail = decltype(MakeBGridDescriptor_K0_N_K1_Tail(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
const CGridDesc_M_N& c_grid_desc_m_n,
@@ -543,9 +458,9 @@ struct DeviceGemmXdlSplitK
has_tail_ = true;
const auto KTail = K - KSplitted * (BatchCount_ - 1);
a_grid_desc_k0_m_k1_tail_ =
DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1_Tail(M, KTail, StrideA);
DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1(M, KTail, StrideA);
b_grid_desc_k0_n_k1_tail_ =
DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1_Tail(KTail, N, StrideB);
DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KTail, N, StrideB);
is_valid &= GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_tail_,
b_grid_desc_k0_n_k1_tail_,
@@ -597,8 +512,8 @@ struct DeviceGemmXdlSplitK
bool has_tail_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail_;
BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_tail_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -707,8 +622,6 @@ struct DeviceGemmXdlSplitK
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
@@ -728,8 +641,6 @@ struct DeviceGemmXdlSplitK
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
@@ -749,8 +660,6 @@ struct DeviceGemmXdlSplitK
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
@@ -770,8 +679,6 @@ struct DeviceGemmXdlSplitK
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
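
The second file in the commit, the CShuffle split-K device op, receives the analogous simplification: the `template <bool IsTail>` makers and their in-class `<false>` specializations are removed, and the surviving makers (the former `<true>` specializations, which handle the general padded case) become the only `MakeAGridDescriptor_AK0_M_AK1` / `MakeBGridDescriptor_BK0_N_BK1`, used for both the regular and tail splits. A minimal sketch of that interface collapse, with assumed toy types and padding logic rather than the real tensor descriptors:

```cpp
#include <cstdio>
#include <tuple>

using index_t = int;

// Hypothetical stand-in (not the library code) for the CShuffle device op after
// this commit: a single static maker with no IsTail template parameter, so the
// regular and tail descriptor members share one type.
struct DeviceOpSketch
{
    static constexpr index_t AK1       = 8;
    static constexpr index_t KPerBlock = 32;

    // Assumed behavior, mirroring the surviving maker: pad KRaw up to a
    // multiple of KPerBlock before splitting into (AK0, AK1).
    static std::tuple<index_t, index_t, index_t>
    MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t /*StrideA*/)
    {
        const index_t K = (KRaw + KPerBlock - 1) / KPerBlock * KPerBlock;
        return {K / AK1, MRaw, AK1}; // extents only: (AK0, M, AK1)
    }

    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));

    AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
    AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_; // same type as the non-tail member
};

int main()
{
    DeviceOpSketch op;
    op.a_grid_desc_ak0_m_ak1_      = DeviceOpSketch::MakeAGridDescriptor_AK0_M_AK1(256, 96, 1);
    op.a_grid_desc_ak0_m_ak1_tail_ = DeviceOpSketch::MakeAGridDescriptor_AK0_M_AK1(256, 40, 1); // 40 padded to 64
    std::printf("tail AK0 = %d\n", std::get<0>(op.a_grid_desc_ak0_m_ak1_tail_)); // 64 / 8 = 8
}
```
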
@@ -33,8 +33,6 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename AGridDesc_AK0_M_AK1_Tail,
typename BGridDesc_BK0_N_BK1_Tail,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
@@ -54,8 +52,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail,
const BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_tail,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
@@ -183,118 +181,7 @@ struct DeviceGemmXdlSplitKCShuffle
return std::make_pair(actual_batch, KSplitted);
}
template <bool IsTail>
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA);
template <bool IsTail>
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB);
/*
* No padding in K
*/
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
{
// return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
assert(K % KPerBlock == 0);
assert(K % AK1 == 0);
const auto a_grid_desc_mraw_k = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(I1, StrideA));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
const auto AK0 = K / AK1;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
{
// return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
assert(K % KPerBlock == 0);
assert(K % BK1 == 0);
const auto b_grid_desc_nraw_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(StrideB, I1));
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
const auto BK0 = K / BK1;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
template <>
static auto MakeAGridDescriptor_AK0_M_AK1<true>(index_t MRaw, index_t KRaw, index_t StrideA)
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -359,8 +246,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
template <>
static auto MakeBGridDescriptor_BK0_N_BK1<true>(index_t KRaw, index_t NRaw, index_t StrideB)
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -481,11 +367,9 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<false>(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<false>(1, 1, 1));
using AGridDesc_AK0_M_AK1_Tail = decltype(MakeAGridDescriptor_AK0_M_AK1<true>(1, 1, 1));
using BGridDesc_BK0_N_BK1_Tail = decltype(MakeBGridDescriptor_BK0_N_BK1<true>(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
@@ -598,9 +482,9 @@ struct DeviceGemmXdlSplitKCShuffle
const auto BKSplitted = actual_batch_and_ksplitted_B.second;
a_grid_desc_ak0_m_ak1_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<false>(MRaw, AKSplitted, StrideA);
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA);
b_grid_desc_bk0_n_bk1_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<false>(BKSplitted, NRaw, StrideB);
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
is_valid_ = GridwiseGemm::CheckValidity(
@@ -613,9 +497,9 @@ struct DeviceGemmXdlSplitKCShuffle
const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);
a_grid_desc_ak0_m_ak1_tail_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, AKTail, StrideA);
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKTail, StrideA);
b_grid_desc_bk0_n_bk1_tail_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<true>(BKTail, NRaw, StrideB);
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKTail, NRaw, StrideB);
is_valid_ &= GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
@@ -668,8 +552,8 @@ struct DeviceGemmXdlSplitKCShuffle
bool is_valid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail_;
BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_tail_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -796,8 +680,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
@@ -817,8 +699,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
@@ -838,8 +718,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
@@ -859,8 +737,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1_Tail,
DeviceOp::BGridDesc_BK0_N_BK1_Tail,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,