Commit 09f3a75e authored by ozturkosu's avatar ozturkosu
Browse files

trace gridwise gemm CheckValidity For Padding

parent 1ff50e78
...@@ -21,24 +21,42 @@ using CElementOp = PassThrough; ...@@ -21,24 +21,42 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // // clang-format off
using DeviceGemmV2_Streamk_Instance = // using DeviceGemmV2_Streamk_Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< // ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
ALayout, BLayout, CLayout, // ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, // ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault, // PassThrough, PassThrough, PassThrough, GemmDefault,
256, // 256,
128, 128, // 128, 128,
64, 8, 8, // 64, 8, 8,
16, 16, // 16, 16,
4, 4, // 4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, // 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, // 2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8, // 1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; // ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on // // clang-format on
using DeviceGemmV2_Streamk_Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
Row, Col, Row,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128,
64, 8, 8,
32, 32,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 16, 1, 16>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -467,11 +467,17 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -467,11 +467,17 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
{ {
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case1"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
arg.Streamk_sel > 0) arg.Streamk_sel > 0)
{ {
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case2"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
...@@ -479,9 +485,16 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -479,9 +485,16 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding)) GemmSpec == GemmSpecialization::KPadding))
{ {
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Case3"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
std::cout << "@EminHari BugFix device_gemm IsSupportedArgument Validity Passed"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return GridwiseGemm::CheckValidity(arg); return GridwiseGemm::CheckValidity(arg);
} }
...@@ -762,7 +775,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -762,7 +775,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{BlockGemmPipelineVersion::v5, "v5"}}; {BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off // clang-format off
str << "DeviceGemmXdlUniversal" str << "DeviceGemmXdlUniversal_StreamK"
<< "<" << "<"
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0] << std::string(ALayout::name)[0]
......
...@@ -956,6 +956,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -956,6 +956,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
...@@ -963,7 +968,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -963,7 +968,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
...@@ -973,6 +979,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -973,6 +979,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
...@@ -992,6 +1003,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -992,6 +1003,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< karg.K << " " << __FILE__ << ":" << __LINE__ << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl; << ", in function: " << __func__ << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
...@@ -1015,6 +1031,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1015,6 +1031,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1029,6 +1049,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1029,6 +1049,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1044,6 +1070,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1044,6 +1070,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1058,6 +1090,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1058,6 +1090,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1075,6 +1113,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1075,6 +1113,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
...@@ -1091,18 +1137,30 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1091,18 +1137,30 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{ {
// Following Should be removed
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << " @EminDebug (gridwise_gemm_sk): Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
...@@ -1112,6 +1170,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1112,6 +1170,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
{ {
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{ {
std::cout << " @EminDebug (gridwise_gemm_sk): Grid size: "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
......
...@@ -1143,6 +1143,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1143,6 +1143,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "@EminDebug (gridwise_gemm_sk): Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment