Commit 5393111f authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments

parent 5ac3d021
...@@ -383,8 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -383,8 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t KPad = math::integer_divide_ceil(K, K0PerBlock * K1) * K0PerBlock * K1; const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
const index_t num_loop = KPad / (K0PerBlock * K1);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
...@@ -841,6 +840,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -841,6 +840,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
} }
}(); }();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
const auto KPad = K0Pad * K1Value; const auto KPad = K0Pad * K1Value;
...@@ -850,8 +851,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -850,8 +851,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_kpad, a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
...@@ -859,11 +858,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -859,11 +858,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else else
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_kpad, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(M)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -884,6 +892,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -884,6 +892,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
} }
}(); }();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
const auto KPad = K0Pad * K1Value; const auto KPad = K0Pad * K1Value;
...@@ -893,8 +903,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -893,8 +903,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_kpad_n, b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
...@@ -902,11 +910,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext ...@@ -902,11 +910,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else else
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_kpad_n, b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(N)), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment