Commit 6570ef7a authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move check for karg.K into CheckValidity()

parent b0e02b8a
...@@ -137,14 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -137,14 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
assert(CalculateKPadded(K) % AK1Value == 0);
return CalculateKPadded(K) / AK1Value; return CalculateKPadded(K) / AK1Value;
} }
else else
{ {
assert(K % AK1Value == 0);
return K / AK1Value; return K / AK1Value;
} }
} }
...@@ -158,14 +154,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -158,14 +154,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
assert(CalculateKPadded(K) % BK1Value == 0);
return CalculateKPadded(K) / BK1Value; return CalculateKPadded(K) / BK1Value;
} }
else else
{ {
assert(K % BK1Value == 0);
return K / BK1Value; return K / BK1Value;
} }
} }
...@@ -530,6 +522,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -530,6 +522,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(karg.K) % AK1Value == 0) ||
!(CalculateKPadded(karg.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment