Commit c9bef1c6 authored by Anthony Chang

coarsely controlled 2nd gemm padding

parent e55b67a0
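
The diff below switches the second GEMM's B1 descriptor from a single unpadded path to a coarse two-way choice: GemmSpecialization::Default keeps the raw N/K lengths, and every other specialization pads both N and K, with finer per-dimension control left as a TODO. The following is a minimal, self-contained sketch of that decision rule, not code from the repository: MakeB1Dims, RoundUp, and the round-up-to-NPerBlock/KPerBlock rule are illustrative assumptions about how the padded lengths are produced upstream, and the enum is trimmed to two values.

    #include <cstdint>

    using index_t = std::int32_t;

    enum class GemmSpecialization
    {
        Default,
        MNKPadding // other variants omitted in this sketch
    };

    // Round x up to the next multiple of tile (assumed padding rule).
    constexpr index_t RoundUp(index_t x, index_t tile) { return ((x + tile - 1) / tile) * tile; }

    struct B1Dims
    {
        index_t N; // second-gemm N, possibly padded
        index_t K; // second-gemm K, possibly padded
    };

    // Coarse rule: Default leaves the raw lengths alone; every other specialization
    // pads BOTH N and K up to tile multiples -- no per-dimension choice yet.
    constexpr B1Dims MakeB1Dims(GemmSpecialization spec,
                                index_t NRaw,
                                index_t KRaw,
                                index_t NPerBlock,
                                index_t KPerBlock)
    {
        if(spec == GemmSpecialization::Default)
        {
            return {NRaw, KRaw};
        }
        return {RoundUp(NRaw, NPerBlock), RoundUp(KRaw, KPerBlock)};
    }

    int main()
    {
        constexpr B1Dims padded = MakeB1Dims(GemmSpecialization::MNKPadding, 60, 70, 32, 32);
        static_assert(padded.N == 64 && padded.K == 96, "rounded up to tile multiples");

        constexpr B1Dims raw = MakeB1Dims(GemmSpecialization::Default, 60, 70, 32, 32);
        static_assert(raw.N == 60 && raw.K == 70, "left unpadded");
        return 0;
    }
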
@@ -308,8 +308,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad both N and K
-            assert(K % BK1 == 0);
             const auto BK0 = K / BK1;
             const auto b_grid_desc_n_k =
@@ -332,8 +330,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                      GemmSpec == GemmSpecialization::MNPadding)
         {
             // pad N, but not K
-            assert(KRaw % BK1 == 0);
             const auto BK0 = KRaw / BK1;
             const auto b_grid_desc_bk0_n_bk1 =
@@ -349,8 +345,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                      GemmSpec == GemmSpecialization::MKPadding)
         {
             // pad K, but not N
-            assert(K % BK1 == 0);
             const auto BK0 = K / BK1;
             const auto b_grid_desc_n_k = transform_tensor_descriptor(
@@ -371,8 +365,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
         else
         {
             // not pad N or K
-            assert(KRaw % BK1 == 0);
             const auto BK0 = KRaw / BK1;
             const auto b_grid_desc_bk0_n_bk1 =
@@ -408,20 +400,42 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
         const auto NPad = N - NRaw;
         const auto KPad = K - KRaw;

-        // TODO ANT: implement padding
-        // not pad N or K
-        assert(KRaw % B1K1 == 0);
-
-        const auto B1K0 = KRaw / B1K1;
-
-        const auto b1_grid_desc_bk0_n_bk1 =
-            transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
-                                        make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
-                                                   make_pass_through_transform(NRaw)),
-                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-        return b1_grid_desc_bk0_n_bk1;
+        // TODO: implement finer-grained padding
+        if constexpr(GemmSpec == GemmSpecialization::Default)
+        {
+            const auto B1K0 = KRaw / B1K1;
+
+            const auto b1_grid_desc_bk0_n_bk1 =
+                transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
+                                            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                                                       make_pass_through_transform(NRaw)),
+                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b1_grid_desc_bk0_n_bk1;
+        }
+        else
+        {
+            // pad both B1N and B1K
+            const auto B1K0 = K / B1K1;
+
+            const auto b1_grid_desc_n_k =
+                transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
+                                            make_tuple(make_right_pad_transform(NRaw, NPad),
+                                                       make_right_pad_transform(KRaw, KPad)),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto b1_grid_desc_bk0_n_bk1 =
+                transform_tensor_descriptor(b1_grid_desc_n_k,
+                                            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                                                       make_pass_through_transform(N)),
+                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b1_grid_desc_bk0_n_bk1;
+        }
     }

     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
 ...
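
For intuition about what the padded branch builds: the first transform right-pads the raw (NRaw, KRaw) lengths out to (N, K), and the second unmerges K into (B1K0, B1K1), yielding a [B1K0, N, B1K1] view. The sketch below mirrors that index math only; it does not use composable_kernel's transform_tensor_descriptor, make_right_pad_transform, or make_unmerge_transform. The B1View type, its row-major stride, and the convention that padded coordinates read as zero are assumptions made for illustration.

    #include <cstdint>
    #include <vector>

    using index_t = std::int32_t;

    struct B1View
    {
        const float* data;  // raw NRaw x KRaw buffer, assumed row-major (stride KRaw)
        index_t NRaw, KRaw; // unpadded lengths
        index_t N, K;       // padded lengths, with K == B1K0 * B1K1
        index_t B1K1;       // innermost split of the padded K dimension

        // Read element (bk0, n, bk1) of the [B1K0, N, B1K1] view.
        // Reverse of the unmerge: k = bk0 * B1K1 + bk1.
        // Effect of the right-pad: coordinates beyond NRaw/KRaw land in the pad region
        // and are treated here as zero.
        float operator()(index_t bk0, index_t n, index_t bk1) const
        {
            const index_t k = bk0 * B1K1 + bk1;
            if(n >= NRaw || k >= KRaw)
            {
                return 0.0f; // padded element
            }
            return data[n * KRaw + k];
        }
    };

    int main()
    {
        // 3x5 raw matrix padded to 4x8, with K split as B1K0 = 2, B1K1 = 4.
        std::vector<float> b1(3 * 5, 1.0f);
        const B1View view{b1.data(), 3, 5, 4, 8, 4};

        const float in_bounds = view(1, 2, 0); // maps to raw (n = 2, k = 4) -> 1.0f
        const float in_pad    = view(1, 2, 3); // raw k = 7 >= KRaw        -> 0.0f
        return (in_bounds == 1.0f && in_pad == 0.0f) ? 0 : 1;
    }
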