Commit bffb335b authored by ozturkosu's avatar ozturkosu
Browse files

gridwise stream-k update

parent 04dd3148
......@@ -946,7 +946,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
......@@ -963,7 +964,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......@@ -998,7 +1000,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
else
{
if(karg.K <= 0)
// if(karg.K <= 0)
// {
// return false;
// }
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = karg.KBatch * KReadVec;
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
return false;
}
......@@ -1095,13 +1104,32 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
// if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
// {
// if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
// {
// std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
// << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
// << std::endl;
// }
// }
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
if(karg.KBatch > 1)
{
return false;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment