Commit 5393111f authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments

parent 5ac3d021
......@@ -383,8 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t KPad = math::integer_divide_ceil(K, K0PerBlock * K1) * K0PerBlock * K1;
const index_t num_loop = KPad / (K0PerBlock * K1);
const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -841,6 +840,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
const auto KPad = K0Pad * K1Value;
......@@ -850,8 +851,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
......@@ -859,11 +858,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -884,6 +892,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
const auto KPad = K0Pad * K1Value;
......@@ -893,8 +903,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
......@@ -902,11 +910,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment