Commit 09f3a75e authored by ozturkosu's avatar ozturkosu
Browse files

trace gridwise gemm CheckValidity For Padding

parent 1ff50e78
......@@ -21,24 +21,42 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
// // clang-format off
// using DeviceGemmV2_Streamk_Instance =
// ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
// ALayout, BLayout, CLayout,
// ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
// PassThrough, PassThrough, PassThrough, GemmDefault,
// 256,
// 128, 128,
// 64, 8, 8,
// 16, 16,
// 4, 4,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// 1, 2, S<1, 32, 1, 8>, 8,
// ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// // clang-format on
using DeviceGemmV2_Streamk_Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
ALayout, BLayout, CLayout,
Row, Col, Row,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128,
64, 8, 8,
16, 16,
4, 4,
32, 32,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
1, 1, S<1, 16, 1, 16>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
......@@ -467,11 +467,17 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
if(!ck::is_xdl_supported())
{
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case1"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
arg.Streamk_sel > 0)
{
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case2"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
......@@ -479,9 +485,16 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case3"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Validity Passed"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return GridwiseGemm::CheckValidity(arg);
}
......@@ -762,7 +775,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
str << "DeviceGemmXdlUniversal_StreamK"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
......
......@@ -956,6 +956,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......@@ -963,7 +968,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......@@ -973,6 +979,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......@@ -992,6 +1003,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......@@ -1015,6 +1031,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
......@@ -1029,6 +1049,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
......@@ -1044,6 +1070,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
......@@ -1058,6 +1090,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
......@@ -1075,6 +1113,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......@@ -1091,18 +1137,30 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
// Following Should be removed
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << " @EminDebug (gridwise_gemm_sk): Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
// check gridwise gemm pipeline
......@@ -1112,6 +1170,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
std::cout << " @EminDebug (gridwise_gemm_sk): Grid size: "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......
......@@ -1143,6 +1143,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment