Commit 7f29ed0b authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove M/N/KPad local variables

parent bb5530af
......@@ -100,9 +100,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
......@@ -113,8 +110,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(make_right_pad_transform(MRaw, M - MRaw),
make_right_pad_transform(KRaw, K - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -138,7 +135,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_right_pad_transform(MRaw, M - MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -154,7 +151,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, K - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -203,9 +200,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
......@@ -216,8 +210,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(make_right_pad_transform(NRaw, N - NRaw),
make_right_pad_transform(KRaw, K - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -241,7 +235,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_right_pad_transform(NRaw, N - NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -257,7 +251,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, K - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -306,16 +300,13 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(make_right_pad_transform(MRaw, M - MRaw),
make_right_pad_transform(NRaw, N - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -325,7 +316,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(make_right_pad_transform(MRaw, M - MRaw), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -335,7 +326,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, N - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment