Commit a9a3a3e2 authored by ozturkosu's avatar ozturkosu

Bug fix: always pad M/N/K in the stream-K GEMM grid descriptors, relax layout-dependent divisibility checks, and add diagnostic logging plus new kernel instances

parent 9c5b2f39
@@ -223,72 +223,88 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
// Pad both M and K to be multiples of the block sizes
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
// GemmSpec == GemmSpecialization::MNKPadding)
// {
// // pad both M and K
// const auto a_grid_desc_m_k =
// transform_tensor_descriptor(a_grid_desc_mraw_kraw,
// make_tuple(make_right_pad_transform(M, MPad - M),
// make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_pass_through_transform(MPad)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// }
// else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
// GemmSpec == GemmSpecialization::MNPadding)
// {
// // pad M, but not K
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_mraw_kraw,
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_right_pad_transform(M, MPad - M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// }
// else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
// GemmSpec == GemmSpecialization::NKPadding)
// {
// // pad K, but not M
// const auto a_grid_desc_m_k = transform_tensor_descriptor(
// a_grid_desc_mraw_kraw,
// make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_m_k,
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// }
// else
// {
// // not pad M or K
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_mraw_kraw,
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_pass_through_transform(M)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// }
}
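    // A minimal sketch of the padding arithmetic the right-pad transforms above
    // rely on (illustrative helper; the name is not used elsewhere): MPad and
    // KPad round the raw extents up to the next tile multiple, so MPad - M and
    // KPad - K are the right-pad amounts and AK0 = KPad / AK1Value.
    __host__ __device__ static constexpr index_t CalculatePaddedExtent(index_t extent,
                                                                       index_t tile)
    {
        // e.g. extent = 1000, tile = 128 -> 1024 (pad amount 24)
        return ((extent + tile - 1) / tile) * tile;
    }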
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
@@ -304,73 +320,89 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
// Pad both N and K to be multiples of the block sizes
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
// GemmSpec == GemmSpecialization::MNKPadding)
// {
// // pad both N and K
// const auto b_grid_desc_n_k =
// transform_tensor_descriptor(b_grid_desc_nraw_kraw,
// make_tuple(make_right_pad_transform(N, NPad - N),
// make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_n_k,
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_pass_through_transform(NPad)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// }
// else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
// GemmSpec == GemmSpecialization::MNPadding)
// {
// // pad N, but not K
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_nraw_kraw,
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// }
// else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
// GemmSpec == GemmSpecialization::MKPadding)
// {
// // pad K, but not N
// const auto b_grid_desc_n_k = transform_tensor_descriptor(
// b_grid_desc_nraw_kraw,
// make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_n_k,
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// }
// else
// {
// // not pad N or K
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_nraw_kraw,
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_pass_through_transform(N)),
// make_tuple(Sequence<1>{}, Sequence<0>{}),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// }
}
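    // A minimal sketch (illustrative helper) of the index mapping produced by
    // the unmerge transform above: coordinate (bk0, n, bk1) of the BK0 x N x BK1
    // descriptor addresses element (n, k) of the padded N x K view.
    __host__ __device__ static constexpr index_t CalculateUnmergedKIndex(index_t bk0,
                                                                         index_t bk1)
    {
        // K is unmerged into (BK0, BK1): k = bk0 * BK1 + bk1,
        // e.g. with BK1Value = 8, (bk0 = 3, bk1 = 5) -> k = 29
        return bk0 * BK1Value + bk1;
    }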
template <typename ABlockDesc_AK0_M_AK1>
@@ -405,43 +437,49 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
// Pad both M and N to be multiples of the block sizes
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
// GemmSpec == GemmSpecialization::MNKPadding)
// {
// // pad M and N
// return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
// make_tuple(make_right_pad_transform(M, MPad - M),
// make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
// GemmSpec == GemmSpecialization::MKPadding)
// {
// // pad M, but not N
// return transform_tensor_descriptor(
// c_grid_desc_mraw_nraw,
// make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
// GemmSpec == GemmSpecialization::NKPadding)
// {
// // pad N, but not M
// return transform_tensor_descriptor(
// c_grid_desc_mraw_nraw,
// make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}));
// }
// else
// {
// // not pad M or N
// return c_grid_desc_mraw_nraw;
// }
}
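    // Worked example of the padding above (numbers illustrative): with
    // M = 1000, N = 1001 and MPerBlock = NPerBlock = 128, the padded C view is
    // 1024 x 1024; coordinates with m >= 1000 or n >= 1001 fall in the pad
    // region and must not be written back to global memory.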
struct Problem
@@ -946,7 +984,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
@@ -963,7 +1002,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
@@ -1029,6 +1069,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1044,6 +1089,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1058,6 +1107,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1075,6 +1128,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
@@ -1091,10 +1150,17 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false;
}
}
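        // The divisibility checks above all share one pattern; a condensed
        // sketch (this lambda is illustrative and not used by the checks):
        const auto is_tile_aligned = [](index_t extent, index_t tile, bool padded) {
            // padded specializations round the extent up internally, so any
            // value passes; unpadded ones need an exact multiple, e.g.
            // M = 1000 with MPerBlock = 128 is rejected without M padding
            return padded || (extent % tile == 0);
        };
        ignore = is_tile_aligned;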
// TODO(Emin): remove this check.
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
......
@@ -43,7 +43,7 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances =
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, /* N, K0, K1 */ S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
@@ -55,7 +55,51 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
// Kernel instance added by Emin
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec,
    // Tiling parameters: how the problem is partitioned into block and wave tiles
    256,             // BlockSize
    128,             // MPerBlock
    128,             // NPerBlock
    64,              // KPerBlock
    8,               // AK1
    8,               // BK1
    32,              // MPerXDL
    32,              // NPerXDL
    2,               // MXdlPerWave
    2,               // NXdlPerWave
    // For tensor A, these define how data is copied from global to shared memory
    S<8, 32, 1>,     // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<1, 0, 2>,      // ABlockTransfer ThreadCluster ArrangeOrder (determined by layout)
    S<1, 0, 2>,      // ABlockTransfer SrcAccessOrder (always 1, 0, 2 when A is row-major)
    2,               // ABlockTransfer SrcVectorDim (always 2 when A is row-major)
    2,               // ABlockTransfer SrcScalarPerVector: vector width for reading A from global memory
    8,               // ABlockTransfer DstScalarPerVector_K1: vector width for writing A to shared memory
    0,               // ABlockLds AddExtraM
    // For tensor B, these define how data is copied from global to shared memory
    S<8, 32, 1>,     // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<1, 0, 2>,      // BBlockTransfer ThreadCluster ArrangeOrder (always 1, 0, 2 when B is column-major)
    S<1, 0, 2>,      // BBlockTransfer SrcAccessOrder (always 1, 0, 2 when B is column-major)
    2,               // BBlockTransfer SrcVectorDim (always 2 when B is column-major)
    2,               // BBlockTransfer SrcScalarPerVector
    8,               // BBlockTransfer DstScalarPerVector_K1
    0,               // BBlockLds AddExtraN
    1,               // CShuffle MXdlPerWavePerShuffle
    1,               // CShuffle NXdlPerWavePerShuffle (1 or 2, depending on the kernel)
    S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    2,               // CShuffleBlockTransferScalarPerVector_NPerBlock
    BlockGemmPipelineScheduler::Interwave,
    BlockGemmPipelineVersion::v1>
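    // Worked arithmetic for the tiling parameters above (a reading of the
    // commented values, not authoritative): the ABlockTransfer cluster
    // S<8, 32, 1> uses 8 * 32 * 1 = 256 threads, i.e. the full block;
    // KPerBlock / AK1 = 64 / 8 = 8 matches the cluster's K0 extent; each
    // thread then covers MPerBlock / 32 = 128 / 32 = 4 rows of A per copy.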
// clang-format on
>;
@@ -84,7 +128,46 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances =
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// Kernel instance added by Emin
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec,
    // Tiling parameters: how the problem is partitioned into block and wave tiles
    256,             // BlockSize
    16,              // MPerBlock
    256,             // NPerBlock
    64,              // KPerBlock
    8,               // AK1
    8,               // BK1
    16,              // MPerXDL
    16,              // NPerXDL
    1,               // MXdlPerWave
    4,               // NXdlPerWave
    // For tensor A, these define how data is copied from global to shared memory
    S<8, 16, 1>,     // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<1, 0, 2>,      // ABlockTransfer ThreadCluster ArrangeOrder (determined by layout)
    S<1, 0, 2>,      // ABlockTransfer SrcAccessOrder (always 1, 0, 2 when A is row-major)
    2,               // ABlockTransfer SrcVectorDim (always 2 when A is row-major)
    2,               // ABlockTransfer SrcScalarPerVector
    8,               // ABlockTransfer DstScalarPerVector_K1
    0,               // ABlockLds AddExtraM
    // For tensor B, these define how data is copied from global to shared memory
    S<8, 16, 1>,     // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<1, 0, 2>,      // BBlockTransfer ThreadCluster ArrangeOrder (always 1, 0, 2 when B is column-major)
    S<1, 0, 2>,      // BBlockTransfer SrcAccessOrder (always 1, 0, 2 when B is column-major)
    2,               // BBlockTransfer SrcVectorDim (always 2 when B is column-major)
    2,               // BBlockTransfer SrcScalarPerVector
    8,               // BBlockTransfer DstScalarPerVector_K1
    0,               // BBlockLds AddExtraN
    1,               // CShuffle MXdlPerWavePerShuffle
    1,               // CShuffle NXdlPerWavePerShuffle
    S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    2,               // CShuffleBlockTransferScalarPerVector_NPerBlock
    BlkGemmPipeSched,
    BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
......
@@ -55,7 +55,28 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
int M;
int N;
int StrideA;
int StrideB;
// For otherwise-unsupported shapes (N not a multiple of 8 but M is), swap M and N along with the A/B strides
if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0)
{
M = std::stoi(argv[9]);
StrideA = std::stoi(argv[12]);
N = std::stoi(argv[8]);
StrideB = std::stoi(argv[11]);
}
else
{
M = std::stoi(argv[8]);
StrideA = std::stoi(argv[11]);
N = std::stoi(argv[9]);
StrideB = std::stoi(argv[12]);
}
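// The swap above effectively profiles the transposed problem: with 8-wide
// vector access the contiguous extent must be a multiple of 8, so a request
// like M = 1024, N = 1001 is re-run as M = 1001, N = 1024 with the A/B
// strides exchanged to match. This appears to pair with the relaxed
// row-major divisibility checks added in the gridwise kernel.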
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
@@ -63,12 +84,12 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
// const int M = std::stoi(argv[8]);
// const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
// const int StrideA = std::stoi(argv[11]);
// const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int Streamk_sel = std::stoi(argv[14]);
const int Grid_size = std::stoi(argv[15]);
......