Commit d3469a55 authored by Jing Zhang

added kpad support to v2r3

parent f5ec04f0
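
The change set below pads the GEMM reduction dimension K so that the v2r3 gridwise kernels no longer require K to be an exact multiple of the per-block K tile. A minimal sketch of the rounding used throughout the diff (the helper name is illustrative and not part of the library; K0PerBlock and K1 are the template parameters referenced in the hunks):

    // Round K up to the per-block K tile, mirroring
    // math::integer_divide_ceil(K, K0PerBlock * K1) * K0PerBlock * K1 from the diff.
    // Illustrative helper only.
    inline int CalculateKPad(int K, int K0PerBlock, int K1)
    {
        const int KTile = K0PerBlock * K1;
        return (K + KTile - 1) / KTile * KTile;
    }
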
@@ -315,11 +315,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
             return false;
         }
 
-        if(problem.K % K1 != 0)
-        {
-            return false;
-        }
-
         return GridwiseGemm::CheckValidity(problem);
     }
@@ -416,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
             << ">"
             << " NumGemmKPrefetchStage: "
             << NumGemmKPrefetchStage << ", "
@@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
           StrideC{StrideC_},
           MPadded{CalculateMPadded(M_)},
           NPadded{CalculateNPadded(N_)},
-          K0{CalculateK0(K)}
+          K0{CalculateK0(K_)}
     {
     }
@@ -383,7 +383,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const index_t num_loop = K / (K0PerBlock * K1);
+        const index_t KPad     = math::integer_divide_ceil(K, K0PerBlock * K1) * K0PerBlock * K1;
+        const index_t num_loop = KPad / (K0PerBlock * K1);
 
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
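
With the padded count, a K that is not a multiple of K0PerBlock * K1 still yields a loop count that covers the ragged tail. A small worked example with assumed values (K = 100, K0PerBlock = 4, K1 = 8; these numbers are not from the commit):

    const int K = 100, K0PerBlock = 4, K1 = 8;            // assumed example values
    const int KTile    = K0PerBlock * K1;                 // 32
    const int KPad     = (K + KTile - 1) / KTile * KTile; // 128
    const int num_loop = KPad / KTile;                    // 4; the old K / KTile gave 3 and dropped the tail
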
@@ -840,11 +841,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
             }
         }();
 
+        const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
+        const auto KPad  = K0Pad * K1Value;
+
+        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
                            make_right_pad_transform(M, MPad - M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -852,8 +862,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
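
For the A tensor the padding happens in two steps before the usual (K0, K1) unmerge: K0 is rounded up to a multiple of K0PerBlock, the M x K descriptor is right-padded to KPad = K0Pad * K1Value, and the padded K is then split into (K0Pad, K1), giving a K0Pad x M x K1 layout. A shape-level sketch, assuming K0 is derived as K / K1 (dims only; the real code builds tensor descriptors via transform_tensor_descriptor as shown above):

    #include <array>

    // Illustrative shape computation only, not the library API.
    std::array<int, 3> AShapeK0MK1(int M, int K, int K0PerBlock, int K1)
    {
        const int K0    = K / K1;                                      // assumes K % K1 == 0
        const int K0Pad = (K0 + K0PerBlock - 1) / K0PerBlock * K0PerBlock;
        // [M, K] -> right-pad K to K0Pad * K1 -> unmerge K into (K0Pad, K1)
        return {K0Pad, M, K1};
    }
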
@@ -874,11 +884,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
             }
         }();
 
+        const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
+        const auto KPad  = K0Pad * K1Value;
+
+        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
                            make_right_pad_transform(N, NPad - N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -886,8 +905,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
         else
        {
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
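
The B tensor gets the same treatment; the only difference is that K is dimension 0 of b_grid_desc_k_n rather than dimension 1, so the right-pad targets the first dimension and the result is a K0Pad x N x K1 layout. Continuing the shape-level sketch from above (same assumptions):

    std::array<int, 3> BShapeK0NK1(int N, int K, int K0PerBlock, int K1)
    {
        const int K0    = K / K1;
        const int K0Pad = (K0 + K0PerBlock - 1) / K0PerBlock * K0PerBlock;
        // [K, N] -> right-pad K (dim 0) -> unmerge K into (K0Pad, K1)
        return {K0Pad, N, K1};
    }
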
@@ -947,6 +966,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
             }
         }
 
+        if(problem.K % K1 != 0)
+        {
+            return false;
+        }
+
         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
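
Since KPad is built as K0Pad * K1Value, the padding only absorbs raggedness at K0PerBlock granularity; K itself must still split exactly into (K0, K1). The device-level K % K1 check removed in the first hunk therefore reappears here on the gridwise side. A one-line sketch of the guard (illustrative signature; the real check reads problem.K and the compile-time K1 as shown in the hunk above):

    // K must remain a multiple of K1; padding only rounds K0 up to a K0PerBlock multiple.
    inline bool IsKDivisibleByK1(int K, int K1) { return K % K1 == 0; }
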