Commit 7a62d4a7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Rename variables M/N/KRaw to M/N/K

parent 41449b67
...@@ -82,19 +82,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -82,19 +82,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1( static auto
index_t MRaw, index_t MPad, index_t KRaw, index_t KPad, index_t StrideA) MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
make_tuple(StrideA, I1));
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
make_tuple(I1, StrideA));
} }
}(); }();
...@@ -108,8 +106,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -108,8 +106,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw), make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(KRaw, KPad - KRaw)), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -126,14 +124,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -126,14 +124,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
assert(KRaw % AK1 == 0); assert(K % AK1 == 0);
const auto AK0 = KRaw / AK1; const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad - MRaw)), make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -147,17 +145,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -147,17 +145,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto AK0 = KPad / AK1; const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k = transform_tensor_descriptor(
transform_tensor_descriptor(a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k, transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -166,14 +163,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -166,14 +163,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else else
{ {
// not pad M or K // not pad M or K
assert(KRaw % AK1 == 0); assert(K % AK1 == 0);
const auto AK0 = KRaw / AK1; const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -181,19 +178,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -181,19 +178,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
} }
static auto MakeBGridDescriptor_BK0_N_BK1( static auto
index_t KRaw, index_t KPad, index_t NRaw, index_t NPad, index_t StrideB) MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
make_tuple(I1, StrideB));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
make_tuple(StrideB, I1));
} }
}(); }();
...@@ -207,8 +202,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -207,8 +202,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_n_k = const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad - NRaw), make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(KRaw, KPad - KRaw)), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -225,14 +220,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -225,14 +220,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
assert(KRaw % BK1 == 0); assert(K % BK1 == 0);
const auto BK0 = KRaw / BK1; const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad - NRaw)), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -246,17 +241,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -246,17 +241,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto BK0 = KPad / BK1; const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = const auto b_grid_desc_n_k = transform_tensor_descriptor(
transform_tensor_descriptor(b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k, transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -265,14 +259,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -265,14 +259,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else else
{ {
// not pad N or K // not pad N or K
assert(KRaw % BK1 == 0); assert(K % BK1 == 0);
const auto BK0 = KRaw / BK1; const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -281,18 +275,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -281,18 +275,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
static auto static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t MPad, index_t NRaw, index_t NPad, index_t StrideC) MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
make_tuple(StrideC, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
make_tuple(I1, StrideC));
} }
}(); }();
...@@ -300,10 +292,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -300,10 +292,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor( return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
c_grid_desc_mraw_nraw, make_tuple(make_right_pad_transform(M, MPad - M),
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw), make_right_pad_transform(N, NPad - N)),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -313,8 +304,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -313,8 +304,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw), make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -324,8 +314,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -324,8 +314,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -393,9 +382,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -393,9 +382,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
index_t MRaw, index_t M,
index_t NRaw, index_t N,
index_t KRaw, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
...@@ -406,29 +395,28 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -406,29 +395,28 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M,
GridwiseGemm::CalculateMPadded(MRaw), GridwiseGemm::CalculateMPadded(M),
KRaw, K,
GridwiseGemm::CalculateKPadded(KRaw), GridwiseGemm::CalculateKPadded(K),
StrideA)}, StrideA)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K,
GridwiseGemm::CalculateKPadded(KRaw), GridwiseGemm::CalculateKPadded(K),
NRaw, N,
GridwiseGemm::CalculateNPadded(NRaw), GridwiseGemm::CalculateNPadded(N),
StrideB)}, StrideB)},
c_grid_desc_m_n_{ c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M,
DeviceOp::MakeCGridDescriptor_M_N(MRaw, GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateMPadded(MRaw), N,
NRaw, GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateNPadded(NRaw),
StrideC)}, StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
kraw_{KRaw} kraw_{K}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -608,9 +596,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -608,9 +596,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
CDataType* p_c, CDataType* p_c,
index_t MRaw, index_t M,
index_t NRaw, index_t N,
index_t KRaw, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
...@@ -621,9 +609,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -621,9 +609,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_c,
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
...@@ -638,9 +626,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -638,9 +626,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
index_t MRaw, index_t M,
index_t NRaw, index_t N,
index_t KRaw, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
...@@ -651,9 +639,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -651,9 +639,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
MRaw, M,
NRaw, N,
KRaw, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment