Commit bffb335b authored by ozturkosu
Browse files

gridwise stream-k update

parent 04dd3148
...@@ -946,7 +946,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -946,7 +946,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
...@@ -963,7 +964,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -963,7 +964,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
...@@ -998,7 +1000,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -998,7 +1000,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
else else
{ {
if(karg.K <= 0) // if(karg.K <= 0)
// {
// return false;
// }
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = karg.KBatch * KReadVec;
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{ {
return false; return false;
} }
...@@ -1095,13 +1104,32 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1095,13 +1104,32 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
} }
} }
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value) // if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
// {
// if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
// {
// std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
// << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
// << std::endl;
// }
// }
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(!karg.IsReduceAdd())
{ {
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ {
<< std::endl; std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
if(karg.KBatch > 1)
{
return false;
}
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment