Commit e9144d38 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use M/N/KPad to name padded lengths

parent 7f29ed0b
...@@ -97,8 +97,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -97,8 +97,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto MPad = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; const auto KPad = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
if constexpr(GemmSpec == GemmSpecialization::MKPadding || if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
...@@ -106,19 +106,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -106,19 +106,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad both M and K // pad both M and K
assert(K % AK1 == 0); assert(K % AK1 == 0);
const auto AK0 = K / AK1; const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, M - MRaw), make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_right_pad_transform(KRaw, K - KRaw)), make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k, transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)), make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -135,7 +135,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -135,7 +135,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, M - MRaw)), make_right_pad_transform(MRaw, MPad - MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -147,11 +147,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -147,11 +147,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad K, but not M // pad K, but not M
assert(K % AK1 == 0); assert(K % AK1 == 0);
const auto AK0 = K / AK1; const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor( const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, K - KRaw)), make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -197,8 +197,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -197,8 +197,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
}(); }();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto NPad = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; const auto KPad = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
if constexpr(GemmSpec == GemmSpecialization::NKPadding || if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
...@@ -206,19 +206,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -206,19 +206,19 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad both N and K // pad both N and K
assert(K % BK1 == 0); assert(K % BK1 == 0);
const auto BK0 = K / BK1; const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, N - NRaw), make_tuple(make_right_pad_transform(NRaw, NPad - NRaw),
make_right_pad_transform(KRaw, K - KRaw)), make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k, transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)), make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -235,7 +235,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -235,7 +235,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, N - NRaw)), make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -247,11 +247,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -247,11 +247,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad K, but not N // pad K, but not N
assert(K % BK1 == 0); assert(K % BK1 == 0);
const auto BK0 = K / BK1; const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor( const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, K - KRaw)), make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad - KRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -297,16 +297,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -297,16 +297,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto MPad = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto NPad = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
if constexpr(GemmSpec == GemmSpecialization::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad M and N // pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw, return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, M - MRaw), make_tuple(make_right_pad_transform(MRaw, MPad - MRaw),
make_right_pad_transform(NRaw, N - NRaw)), make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -316,7 +316,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -316,7 +316,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N // pad M, but not N
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, M - MRaw), make_pass_through_transform(NRaw)), make_tuple(make_right_pad_transform(MRaw, MPad - MRaw), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -326,7 +326,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -326,7 +326,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M // pad N, but not M
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_mraw_nraw, c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, N - NRaw)), make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad - NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment