Commit 3a558e59 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move AK0/BK0 compute logic into GridwiseGemm

parent 7a62d4a7
......@@ -82,8 +82,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto
MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -100,10 +100,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
......@@ -124,10 +120,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
......@@ -141,10 +133,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
......@@ -163,10 +151,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
{
// not pad M or K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
......@@ -178,8 +162,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
static auto
MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -196,10 +180,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
......@@ -220,10 +200,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
......@@ -237,10 +213,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
......@@ -259,10 +231,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
{
// not pad N or K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
......@@ -325,8 +293,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// GridwiseGemm
......@@ -338,6 +306,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
......@@ -399,13 +368,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GridwiseGemm::CalculateMPadded(M),
K,
GridwiseGemm::CalculateKPadded(K),
StrideA)},
StrideA,
GridwiseGemm::CalculateAK0(K))},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K,
GridwiseGemm::CalculateKPadded(K),
N,
GridwiseGemm::CalculateNPadded(N),
StrideB)},
StrideB,
GridwiseGemm::CalculateBK0(K))},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M,
GridwiseGemm::CalculateMPadded(M),
N,
......
......@@ -79,6 +79,7 @@ template <typename FloatAB,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
......@@ -150,6 +151,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return (K + KPerBlock - 1) / KPerBlock * KPerBlock;
}
__host__ __device__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
assert(CalculateKPadded(K) % AK1 == 0);
return CalculateKPadded(K) / AK1;
}
else
{
assert(K % AK1 == 0);
return K / AK1;
}
}
__host__ __device__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
assert(CalculateKPadded(K) % BK1 == 0);
return CalculateKPadded(K) / BK1;
}
else
{
assert(K % BK1 == 0);
return K / BK1;
}
}
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment