Commit f84c49fa authored by ozturkosu's avatar ozturkosu
Browse files

update

parent 114674c1
...@@ -239,72 +239,74 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -239,72 +239,74 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1; return a_grid_desc_ak0_m_ak1;
// using GemmSpecialization = tensor_operation::device::GemmSpecialization; #if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
// GemmSpec == GemmSpecialization::MNKPadding) if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
// { GemmSpec == GemmSpecialization::MNKPadding)
// // pad both M and K {
// const auto a_grid_desc_m_k = // pad both M and K
// transform_tensor_descriptor(a_grid_desc_mraw_kraw, const auto a_grid_desc_m_k =
// make_tuple(make_right_pad_transform(M, MPad - M), transform_tensor_descriptor(a_grid_desc_mraw_kraw,
// make_right_pad_transform(K, KPad - K)), make_tuple(make_right_pad_transform(M, MPad - M),
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_m_k, const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), a_grid_desc_m_k,
// make_pass_through_transform(MPad)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(MPad)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// } return a_grid_desc_ak0_m_ak1;
// else if constexpr(GemmSpec == GemmSpecialization::MPadding || }
// GemmSpec == GemmSpecialization::MNPadding) else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
// { GemmSpec == GemmSpecialization::MNPadding)
// // pad M, but not K {
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( // pad M, but not K
// a_grid_desc_mraw_kraw, const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), a_grid_desc_mraw_kraw,
// make_right_pad_transform(M, MPad - M)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_right_pad_transform(M, MPad - M)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// } return a_grid_desc_ak0_m_ak1;
// else if constexpr(GemmSpec == GemmSpecialization::KPadding || }
// GemmSpec == GemmSpecialization::NKPadding) else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
// { GemmSpec == GemmSpecialization::NKPadding)
// // pad K, but not M {
// const auto a_grid_desc_m_k = transform_tensor_descriptor( // pad K, but not M
// a_grid_desc_mraw_kraw, const auto a_grid_desc_m_k = transform_tensor_descriptor(
// make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), a_grid_desc_mraw_kraw,
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// a_grid_desc_m_k, const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), a_grid_desc_m_k,
// make_pass_through_transform(M)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(M)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// } return a_grid_desc_ak0_m_ak1;
// else }
// { else
// // not pad M or K {
// const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( // not pad M or K
// a_grid_desc_mraw_kraw, const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), a_grid_desc_mraw_kraw,
// make_pass_through_transform(M)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(M)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return a_grid_desc_ak0_m_ak1;
// } return a_grid_desc_ak0_m_ak1;
}
#endif
} }
__device__ static auto MakeBGridDescriptor_BK0_N_BK1( __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
...@@ -337,72 +339,74 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -337,72 +339,74 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
// using GemmSpecialization = tensor_operation::device::GemmSpecialization; #if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
// GemmSpec == GemmSpecialization::MNKPadding) if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
// { GemmSpec == GemmSpecialization::MNKPadding)
// // pad both N and K {
// const auto b_grid_desc_n_k = // pad both N and K
// transform_tensor_descriptor(b_grid_desc_nraw_kraw, const auto b_grid_desc_n_k =
// make_tuple(make_right_pad_transform(N, NPad - N), transform_tensor_descriptor(b_grid_desc_nraw_kraw,
// make_right_pad_transform(K, KPad - K)), make_tuple(make_right_pad_transform(N, NPad - N),
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_n_k, const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), b_grid_desc_n_k,
// make_pass_through_transform(NPad)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(NPad)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// } return b_grid_desc_bk0_n_bk1;
// else if constexpr(GemmSpec == GemmSpecialization::NPadding || }
// GemmSpec == GemmSpecialization::MNPadding) else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
// { GemmSpec == GemmSpecialization::MNPadding)
// // pad N, but not K {
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( // pad N, but not K
// b_grid_desc_nraw_kraw, const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), b_grid_desc_nraw_kraw,
// make_right_pad_transform(N, NPad - N)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// } return b_grid_desc_bk0_n_bk1;
// else if constexpr(GemmSpec == GemmSpecialization::KPadding || }
// GemmSpec == GemmSpecialization::MKPadding) else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
// { GemmSpec == GemmSpecialization::MKPadding)
// // pad K, but not N {
// const auto b_grid_desc_n_k = transform_tensor_descriptor( // pad K, but not N
// b_grid_desc_nraw_kraw, const auto b_grid_desc_n_k = transform_tensor_descriptor(
// make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), b_grid_desc_nraw_kraw,
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_grid_desc_n_k, const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), b_grid_desc_n_k,
// make_pass_through_transform(N)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(N)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// } return b_grid_desc_bk0_n_bk1;
// else }
// { else
// // not pad N or K {
// const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( // not pad N or K
// b_grid_desc_nraw_kraw, const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
// make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), b_grid_desc_nraw_kraw,
// make_pass_through_transform(N)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
// make_tuple(Sequence<1>{}, Sequence<0>{}), make_pass_through_transform(N)),
// make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// return b_grid_desc_bk0_n_bk1;
// } return b_grid_desc_bk0_n_bk1;
}
#endif
} }
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
...@@ -443,43 +447,45 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -443,43 +447,45 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_right_pad_transform(N, NPad - N)), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// using GemmSpecialization = tensor_operation::device::GemmSpecialization; #if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
// GemmSpec == GemmSpecialization::MNKPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
// { GemmSpec == GemmSpecialization::MNKPadding)
// // pad M and N {
// return transform_tensor_descriptor(c_grid_desc_mraw_nraw, // pad M and N
// make_tuple(make_right_pad_transform(M, MPad - M), return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
// make_right_pad_transform(N, NPad - N)), make_tuple(make_right_pad_transform(M, MPad - M),
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
// } make_tuple(Sequence<0>{}, Sequence<1>{}));
// else if constexpr(GemmSpec == GemmSpecialization::MPadding || }
// GemmSpec == GemmSpecialization::MKPadding) else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
// { GemmSpec == GemmSpecialization::MKPadding)
// // pad M, but not N {
// return transform_tensor_descriptor( // pad M, but not N
// c_grid_desc_mraw_nraw, return transform_tensor_descriptor(
// make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), c_grid_desc_mraw_nraw,
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
// } make_tuple(Sequence<0>{}, Sequence<1>{}));
// else if constexpr(GemmSpec == GemmSpecialization::NPadding || }
// GemmSpec == GemmSpecialization::NKPadding) else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
// { GemmSpec == GemmSpecialization::NKPadding)
// // pad N, but not M {
// return transform_tensor_descriptor( // pad N, but not M
// c_grid_desc_mraw_nraw, return transform_tensor_descriptor(
// make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), c_grid_desc_mraw_nraw,
// make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
// make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
// } make_tuple(Sequence<0>{}, Sequence<1>{}));
// else }
// { else
// // not pad M or N {
// return c_grid_desc_mraw_nraw; // not pad M or N
// } return c_grid_desc_mraw_nraw;
}
#endif
} }
struct Problem struct Problem
...@@ -1069,10 +1075,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1069,10 +1075,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
...@@ -1107,10 +1110,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1107,10 +1110,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl; << __LINE__ << ", in function: " << __func__ << std::endl;
} }
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false; return false;
} }
} }
...@@ -1128,12 +1128,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1128,12 +1128,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
...@@ -1150,27 +1145,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1150,27 +1145,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl; << std::endl;
} }
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
return false; return false;
} }
} }
// @Emin Need to Remove This !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
}
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -43,7 +43,7 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -43,7 +43,7 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances =
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly // Compute friendly
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, "N, K0, K1" S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
...@@ -52,54 +52,13 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -52,54 +52,13 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances =
// DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// AGPR Spill when use permuted lds layout. so, use padding for these two. // AGPR Spill when use permuted lds layout. so, use padding for these two.
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// // Emin Added Kernel Instance
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec,
// Tiling Parameters - How to partition 'Block tiling - wave tiling'
256, // Block Size
128, // MPer Block
128, // NPer Block
64, // KPer Block
8, // AK1 ::
8, // BK1 float4 float8
32, // MPer XDL
32, // NPer XDL
2, // MXdl Per Wave
2, // NXdl Per Wave
// For Tensor A these define how to copy data from Global to Shared Mem
S<8, 32, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder !!!!! Determined by Layout
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder !!!!!! Determined by Layout , Always 1-0-2 If A is row major ,
2, // ABlockTransfer SrcVectorDim !! If A is row major this is always 2
2, // ABlockTransfer SrcScalar PerVector // How you read 'A tensor' data from global memory
8, // ABlockTransfer DstScalar PerVector_K1 // How you write 'A tensor' data to shared memory
0, // ABlockLds AddExtraM
// For Tensor B these define how to copy data from Global to Shared Mem
S<8, 32, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder Always 1-0-2 If B is col major
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder Always 1-0-2 If B is col major
2, // BBlockTransfer SrcVectorDim !! If B is column major this is always 2
2, // BBlockTransfer SrcScalar PerVector
8, // BlockTransfer DstScalar PerVector_k1
0, // B BlockLdsAddExtraN
1, // CShuffle MXdlPerWave PerShuffle ::
1, // CShuffle NXdlPerWave 2 OR 1 it depens on kernel sometimes only 1
S<1, 16, 1, 16>,
2,
BlockGemmPipelineScheduler::Interwave,
BlockGemmPipelineVersion::v1>
// ...existing code...
// clang-format on // clang-format on
>; >;
...@@ -128,46 +87,9 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances = ...@@ -128,46 +87,9 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances =
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// // Emin Added Kernel Instance
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec,
// Tiling Parameters - How to partition 'Block tiling - wave tiling
256, // Block Size
16, // MPer Block
256, // NPer Block
64, // KPer Block
8, // AK1 ::
8, // BK1 float4 float8
16, // MPer XDL
16, // NPer XDL
1, // MXdl Per Wave
4, // NXdl Per Wave
// For Tensor A these define how to copy data from Global to Shared Mem
S<8, 16, 1>, //S<8, 32, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder !!!!! Determined by Layout
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder !!!!!! Determined by Layout , Always 1-0-2 If A is row major ,
2, // ABlockTransfer SrcVectorDim !! If A is row major this is always 2
2, // ABlockTransfer SrcScalar PerVector // How you read 'A tensor' data from global memory
8, // ABlockTransfer DstScalar PerVector_K1 // How you write 'A tensor' data to shared memory
0, // ABlockLds AddExtraM
// For Tensor B these define how to copy data from Global to Shared Mem
S<8, 16, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder Always 1-0-2 If B is col major
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder Always 1-0-2 If B is col major
2, // BBlockTransfer SrcVectorDim !! If B is column major this is always 2 // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
2, // BBlockTransfer SrcScalar PerVector
8,
0,
1,
1,
S<1, 16, 1, 16>,
2,
BlkGemmPipeSched,
BlockGemmPipelineVersion::v2>
// clang-format on // clang-format on
>; >;
} // namespace instance } // namespace instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment