Commit 7a62d4a7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Rename variables M/N/KRaw to M/N/K

parent 41449b67
......@@ -82,19 +82,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(
index_t MRaw, index_t MPad, index_t KRaw, index_t KPad, index_t StrideA)
static auto
MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
......@@ -108,8 +106,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -126,14 +124,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
assert(K % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad - MRaw)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -147,17 +145,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -166,14 +163,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
assert(K % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -181,19 +178,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
static auto MakeBGridDescriptor_BK0_N_BK1(
index_t KRaw, index_t KPad, index_t NRaw, index_t NPad, index_t StrideB)
static auto
MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
......@@ -207,8 +202,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad - NRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -225,14 +220,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
assert(K % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -246,17 +241,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw),
make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -265,14 +259,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
assert(K % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -281,18 +275,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t MPad, index_t NRaw, index_t NPad, index_t StrideC)
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
......@@ -300,12 +292,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
......@@ -313,8 +304,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_pass_through_transform(NRaw)),
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -324,8 +314,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw),
make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -393,9 +382,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
......@@ -406,29 +395,28 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw,
GridwiseGemm::CalculateMPadded(MRaw),
KRaw,
GridwiseGemm::CalculateKPadded(KRaw),
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M,
GridwiseGemm::CalculateMPadded(M),
K,
GridwiseGemm::CalculateKPadded(K),
StrideA)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw,
GridwiseGemm::CalculateKPadded(KRaw),
NRaw,
GridwiseGemm::CalculateNPadded(NRaw),
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K,
GridwiseGemm::CalculateKPadded(K),
N,
GridwiseGemm::CalculateNPadded(N),
StrideB)},
c_grid_desc_m_n_{
DeviceOp::MakeCGridDescriptor_M_N(MRaw,
GridwiseGemm::CalculateMPadded(MRaw),
NRaw,
GridwiseGemm::CalculateNPadded(NRaw),
StrideC)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M,
GridwiseGemm::CalculateMPadded(M),
N,
GridwiseGemm::CalculateNPadded(N),
StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
kraw_{KRaw}
kraw_{K}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -608,9 +596,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
......@@ -621,9 +609,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return Argument{p_a,
p_b,
p_c,
MRaw,
NRaw,
KRaw,
M,
N,
K,
StrideA,
StrideB,
StrideC,
......@@ -638,9 +626,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
......@@ -651,9 +639,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
MRaw,
NRaw,
KRaw,
M,
N,
K,
StrideA,
StrideB,
StrideC,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment