Commit 06f57782 authored by Jianfeng yan

regress to using 1 grid_desc

parent 308146e7
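This commit drops the separate `_Tail` grid-descriptor types and reuses one descriptor type for both the even k-slices and the remainder slice. For orientation, a minimal sketch (not the library code; the helper name `split_k_tail` and return shape are illustrative assumptions) of the split-K arithmetic the diff below relies on:

#include <cassert>
#include <utility>

using index_t = int;

// K is carved into batch_count slices of KSplitted elements; the leftover,
// KTail = K - KSplitted * (batch_count - 1), is what the *_tail_ descriptors
// describe (this mirrors the KTail computation visible in the Arguments ctor).
std::pair<bool, index_t> split_k_tail(index_t K, index_t KSplitted, index_t batch_count)
{
    assert(batch_count >= 1 && KSplitted * (batch_count - 1) < K);
    const index_t KTail    = K - KSplitted * (batch_count - 1);
    const bool    has_tail = (KTail != KSplitted); // last slice differs in length
    return {has_tail, KTail};
}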
@@ -29,8 +29,6 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
-          typename AGridDesc_K0_M_K1_Tail,
-          typename BGridDesc_K0_N_K1_Tail,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -50,8 +48,8 @@ __global__ void
            const index_t batch_count,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-           const AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail,
-           const BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail,
+           const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail,
+           const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_tail,
            const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
@@ -173,7 +171,7 @@ struct DeviceGemmXdlSplitK
         return std::make_pair(actual_batch, KSplitted);
     }

-    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
         const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
         const index_t K0 = KPadded / K1;
@@ -217,7 +215,7 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
         const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
@@ -262,87 +260,6 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        // return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(M)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
-
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        // return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
-
     static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
         const auto c_grid_desc_m_n = [&]() {
@@ -378,11 +295,9 @@ struct DeviceGemmXdlSplitK
         }
     }

     using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
-    using AGridDesc_K0_M_K1_Tail = decltype(MakeAGridDescriptor_K0_M_K1_Tail(1, 1, 1));
-    using BGridDesc_K0_N_K1_Tail = decltype(MakeBGridDescriptor_K0_N_K1_Tail(1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     static constexpr auto MakeBlock2CTileMap(index_t batch_count,
                                              const CGridDesc_M_N& c_grid_desc_m_n,
@@ -543,9 +458,9 @@ struct DeviceGemmXdlSplitK
                 has_tail_ = true;
                 const auto KTail = K - KSplitted * (BatchCount_ - 1);
                 a_grid_desc_k0_m_k1_tail_ =
-                    DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1_Tail(M, KTail, StrideA);
+                    DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1(M, KTail, StrideA);
                 b_grid_desc_k0_n_k1_tail_ =
-                    DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1_Tail(KTail, N, StrideB);
+                    DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KTail, N, StrideB);

                 is_valid &= GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_tail_,
                                                         b_grid_desc_k0_n_k1_tail_,
@@ -597,8 +512,8 @@ struct DeviceGemmXdlSplitK
         bool has_tail_;
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail_;
-        BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail_;
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_tail_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -707,8 +622,6 @@ struct DeviceGemmXdlSplitK
                     CDataType,
                     remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
                     remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
@@ -728,8 +641,6 @@ struct DeviceGemmXdlSplitK
                     CDataType,
                     remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
                     remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
@@ -749,8 +660,6 @@ struct DeviceGemmXdlSplitK
                     CDataType,
                     remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
                     remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
@@ -770,8 +679,6 @@ struct DeviceGemmXdlSplitK
                     CDataType,
                     remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
                     remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
...
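The makers this first file keeps pad K before unmerging it into (K0, K1), which is what lets one descriptor type stand in for both the even slices (KSplitted) and the remainder (KTail): `decltype` of a single maker then covers both uses. A small illustrative sketch of that padding arithmetic, with assumed free-function names (`integer_divide_ceil` mirrors `math::integer_divide_ceil`; `padded_k0` is hypothetical):

using index_t = int;

index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

// Mirrors the KPadded / K0 computation kept in MakeAGridDescriptor_K0_M_K1:
// round K up to a multiple of K1 * K0PerBlock, then split off the K1 factor.
index_t padded_k0(index_t K, index_t K1, index_t K0PerBlock)
{
    const index_t KPadded = integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
    return KPadded / K1; // always a whole multiple of K0PerBlock, for any K
}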
@@ -33,8 +33,6 @@ template <typename GridwiseGemm,
           typename CElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
-          typename AGridDesc_AK0_M_AK1_Tail,
-          typename BGridDesc_BK0_N_BK1_Tail,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
@@ -54,8 +52,8 @@ __global__ void
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
-           const AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail,
-           const BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail,
+           const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail,
+           const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_tail,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
@@ -183,118 +181,7 @@ struct DeviceGemmXdlSplitKCShuffle
         return std::make_pair(actual_batch, KSplitted);
     }

-    template <bool IsTail>
-    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA);
-
-    template <bool IsTail>
-    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB);
-
-    /*
-     * No padding in K
-     */
-    template <>
-    static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
-    {
-        // return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
-        assert(K % KPerBlock == 0);
-        assert(K % AK1 == 0);
-
-        const auto a_grid_desc_mraw_k = [&]() {
-            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(I1, StrideA));
-            }
-        }();
-
-        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto MPad = M - MRaw;
-        const auto AK0  = K / AK1;
-
-        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                     GemmSpec == GemmSpecialization::MKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad M, but not K
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_mraw_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_right_pad_transform(MRaw, MPad)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else
-        {
-            // not pad M or K
-            const auto a_grid_desc_ak0_m_ak1 =
-                transform_tensor_descriptor(a_grid_desc_mraw_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                                                       make_pass_through_transform(MRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-    }
-
-    template <>
-    static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
-    {
-        // return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
-        assert(K % KPerBlock == 0);
-        assert(K % BK1 == 0);
-
-        const auto b_grid_desc_nraw_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(StrideB, I1));
-            }
-        }();
-
-        const auto N    = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
-        const auto NPad = N - NRaw;
-        const auto BK0  = K / BK1;
-
-        if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                     GemmSpec == GemmSpecialization::NKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad N, but not K
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_nraw_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_right_pad_transform(NRaw, NPad)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else
-        {
-            // not pad N or K
-            const auto b_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b_grid_desc_nraw_k,
-                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                                                       make_pass_through_transform(NRaw)),
-                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-    }
-
-    template <>
-    static auto MakeAGridDescriptor_AK0_M_AK1<true>(index_t MRaw, index_t KRaw, index_t StrideA)
+    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -359,8 +246,7 @@ struct DeviceGemmXdlSplitKCShuffle
         }
     }

-    template <>
-    static auto MakeBGridDescriptor_BK0_N_BK1<true>(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -481,11 +367,9 @@ struct DeviceGemmXdlSplitKCShuffle
         }
     }

-    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1<false>(1, 1, 1));
-    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1<false>(1, 1, 1));
-    using AGridDesc_AK0_M_AK1_Tail = decltype(MakeAGridDescriptor_AK0_M_AK1<true>(1, 1, 1));
-    using BGridDesc_BK0_N_BK1_Tail = decltype(MakeBGridDescriptor_BK0_N_BK1<true>(1, 1, 1));
+    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
+    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     struct ComputePtrOffsetOfStridedBatch
     {
@@ -598,9 +482,9 @@ struct DeviceGemmXdlSplitKCShuffle
             const auto BKSplitted = actual_batch_and_ksplitted_B.second;

             a_grid_desc_ak0_m_ak1_ =
-                DeviceOp::MakeAGridDescriptor_AK0_M_AK1<false>(MRaw, AKSplitted, StrideA);
+                DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA);
             b_grid_desc_bk0_n_bk1_ =
-                DeviceOp::MakeBGridDescriptor_BK0_N_BK1<false>(BKSplitted, NRaw, StrideB);
+                DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB);
             c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);

             is_valid_ = GridwiseGemm::CheckValidity(
@@ -613,9 +497,9 @@ struct DeviceGemmXdlSplitKCShuffle
                 const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);

                 a_grid_desc_ak0_m_ak1_tail_ =
-                    DeviceOp::MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, AKTail, StrideA);
+                    DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKTail, StrideA);
                 b_grid_desc_bk0_n_bk1_tail_ =
-                    DeviceOp::MakeBGridDescriptor_BK0_N_BK1<true>(BKTail, NRaw, StrideB);
+                    DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKTail, NRaw, StrideB);

                 is_valid_ &= GridwiseGemm::CheckValidity(
                     a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
@@ -668,8 +552,8 @@ struct DeviceGemmXdlSplitKCShuffle
         bool is_valid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail_;
-        BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail_;
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_tail_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -796,8 +680,6 @@ struct DeviceGemmXdlSplitKCShuffle
                     CElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    DeviceOp::AGridDesc_AK0_M_AK1_Tail,
-                    DeviceOp::BGridDesc_BK0_N_BK1_Tail,
                     CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
                     Block2CTileMap,
@@ -817,8 +699,6 @@ struct DeviceGemmXdlSplitKCShuffle
                     CElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    DeviceOp::AGridDesc_AK0_M_AK1_Tail,
-                    DeviceOp::BGridDesc_BK0_N_BK1_Tail,
                     CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
                     Block2CTileMap,
@@ -838,8 +718,6 @@ struct DeviceGemmXdlSplitKCShuffle
                     CElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    DeviceOp::AGridDesc_AK0_M_AK1_Tail,
-                    DeviceOp::BGridDesc_BK0_N_BK1_Tail,
                     CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
                     Block2CTileMap,
@@ -859,8 +737,6 @@ struct DeviceGemmXdlSplitKCShuffle
                     CElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    DeviceOp::AGridDesc_AK0_M_AK1_Tail,
-                    DeviceOp::BGridDesc_BK0_N_BK1_Tail,
                     CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
                     Block2CTileMap,
...
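After both files are changed, the kernels receive the main and tail descriptors as two values of a single type, and the has_tail_/is_valid_ bookkeeping stays on the host. A hypothetical, plain-C++ sketch of the per-slice choice this uniform typing enables (the function `run_slice` and its batch-index argument are illustrative, not the library's dispatch):

using index_t = int;

template <typename ADesc, typename BDesc>
void run_slice(index_t g_idx, index_t batch_count, bool has_tail,
               const ADesc& a_main, const ADesc& a_tail,
               const BDesc& b_main, const BDesc& b_tail)
{
    // Only the last k-slice uses the tail descriptors.
    const bool is_tail = has_tail && (g_idx == batch_count - 1);

    // Both operands of ?: now have the same type, so one reference binds to
    // either; with the old *_Tail types this required separate code paths.
    const ADesc& a = is_tail ? a_tail : a_main;
    const BDesc& b = is_tail ? b_tail : b_main;

    // ... dispatch the gridwise GEMM on (a, b) for this k-slice ...
    (void)a;
    (void)b;
}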