Commit 41449b67 authored by Po-Yen, Chen
Browse files

Replace duplicated local variables with parameters

parent e9144d38
......@@ -82,7 +82,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
static auto MakeAGridDescriptor_AK0_M_AK1(
index_t MRaw, index_t MPad, index_t KRaw, index_t KPad, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -97,9 +98,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}();
const auto MPad = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto KPad = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
......@@ -149,9 +147,10 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad - KRaw)),
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -182,7 +181,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
static auto MakeBGridDescriptor_BK0_N_BK1(
index_t KRaw, index_t KPad, index_t NRaw, index_t NPad, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -197,9 +197,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}();
const auto NPad = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto KPad = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
......@@ -249,9 +246,10 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad - KRaw)),
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -282,7 +280,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t MPad, index_t NRaw, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
......@@ -297,14 +296,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}();
const auto MPad = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto NPad = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
......@@ -316,7 +313,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw), make_pass_through_transform(NRaw)),
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -326,7 +324,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(make_pass_through_transform(MRaw),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -337,9 +336,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -406,9 +405,24 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw,
GridwiseGemm::CalculateMPadded(MRaw),
KRaw,
GridwiseGemm::CalculateKPadded(KRaw),
StrideA)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw,
GridwiseGemm::CalculateKPadded(KRaw),
NRaw,
GridwiseGemm::CalculateNPadded(NRaw),
StrideB)},
c_grid_desc_m_n_{
DeviceOp::MakeCGridDescriptor_M_N(MRaw,
GridwiseGemm::CalculateMPadded(MRaw),
NRaw,
GridwiseGemm::CalculateNPadded(NRaw),
StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
......
......@@ -135,6 +135,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Returns M rounded up to a multiple of MPerBlock, using the integer
// ceil-division idiom (M + MPerBlock - 1) / MPerBlock.
// NOTE(review): assumes M >= 0 and MPerBlock > 0 -- confirm against callers.
__host__ __device__ static auto CalculateMPadded(index_t M)
{
    // Number of MPerBlock-sized tiles needed to cover M.
    const auto m_tile_count = (M + MPerBlock - 1) / MPerBlock;
    return m_tile_count * MPerBlock;
}
// Returns N rounded up to a multiple of NPerBlock, using the integer
// ceil-division idiom (N + NPerBlock - 1) / NPerBlock.
// NOTE(review): assumes N >= 0 and NPerBlock > 0 -- confirm against callers.
__host__ __device__ static auto CalculateNPadded(index_t N)
{
    // Number of NPerBlock-sized tiles needed to cover N.
    const auto n_tile_count = (N + NPerBlock - 1) / NPerBlock;
    return n_tile_count * NPerBlock;
}
// Returns K rounded up to a multiple of KPerBlock, using the integer
// ceil-division idiom (K + KPerBlock - 1) / KPerBlock.
// NOTE(review): assumes K >= 0 and KPerBlock > 0 -- confirm against callers.
__host__ __device__ static auto CalculateKPadded(index_t K)
{
    // Number of KPerBlock-sized tiles needed to cover K.
    const auto k_tile_count = (K + KPerBlock - 1) / KPerBlock;
    return k_tile_count * KPerBlock;
}
// Pipeline type used by this gridwise GEMM: the cv/ref-stripped type of the
// expression GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>().
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment