Commit 1f9c0a13 authored by Jakub Piasecki

tmp save

parents da8e50dd 0e5e29c4
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
                .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
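The same KPack/single-rate selection pattern recurs in every gridwise GEMM hunk of this commit. A minimal host-side sketch of the idea follows; select_k_pack and the tag types are illustrative stand-ins for CK's MfmaSelector machinery, not its actual API. The point: for fp16/bf16 operands whose packed K chunk (the lcm of AK1 and BK1) is 4 or less, a double-rate MFMA (k_per_blk == 8) cannot be fed a full operand, so the selector is forced to a single-rate instruction.

// Illustrative sketch only: select_k_pack() and the tag types below stand in
// for CK's MfmaSelector; the k_per_blk values (4 single-rate, 8 double-rate)
// mirror the gfx9 fp16 MFMA variants. Compiles and checks on the host.
#include <algorithm>
#include <numeric>
#include <type_traits>

struct half_t;  // opaque fp16 tag (stand-in for ck::half_t)
struct bhalf_t; // opaque bf16 tag (stand-in for ck::bhalf_t)

template <typename ComputeType, int AK1, int BK1>
constexpr int select_k_pack()
{
    constexpr int lcm_AK1_BK1 = std::lcm(AK1, BK1);
    // fp16/bf16 with a short K chunk cannot feed a double-rate MFMA,
    // so fall back to the single-rate instruction (k_per_blk == 4).
    constexpr bool is_single_rate_mfma =
        (std::is_same_v<ComputeType, half_t> || std::is_same_v<ComputeType, bhalf_t>) &&
        lcm_AK1_BK1 <= 4;
    constexpr int k_per_blk = is_single_rate_mfma ? 4 : 8;
    return std::max(lcm_AK1_BK1, k_per_blk);
}

static_assert(select_k_pack<half_t, 4, 4>() == 4); // single-rate path
static_assert(select_k_pack<half_t, 8, 8>() == 8); // double-rate path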
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
                selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
                .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack =
            math::max(lcm_AK1_BK1,
                      MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
                          selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack =
            math::max(lcm_AK1_BK1,
                      MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
                          selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
@@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
                selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
@@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
    static constexpr bool is_single_rate_mfma =
        ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
         lcm_AK1_BK1 <= 4)
            ? true
            : false;
    static constexpr index_t KPack =
        math::max(lcm_AK1_BK1,
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
                      selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...
@@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
    static constexpr bool is_single_rate_mfma =
        ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
         lcm_AK1_BK1 <= 4)
            ? true
            : false;
    static constexpr index_t KPack =
        math::max(lcm_AK1_BK1,
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
                      selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __host__ static auto CalculateMPadded(index_t M)
...
@@ -869,9 +869,16 @@ struct GridwiseGemm_xdl_cshuffle_v2
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
        constexpr bool is_single_rate_mfma =
            ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
                selected_mfma.k_per_blk);

        // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
        //     BlockSize,
...
@@ -147,9 +147,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
    static constexpr bool is_single_rate_mfma =
        ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
         lcm_AK1_BK1 <= 4)
            ? true
            : false;
    static constexpr index_t KPack =
        math::max(lcm_AK1_BK1,
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
                      selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...
@@ -155,9 +155,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
    static constexpr bool is_single_rate_mfma =
        ((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
         lcm_AK1_BK1 <= 4)
            ? true
            : false;
    static constexpr index_t KPack =
        math::max(lcm_AK1_BK1,
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
                      selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1424,7 +1431,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        // b scale
        // static_assert(KPerBlock <= ScaleBlockK);
        static constexpr auto mfma =
            MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
        static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
        static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
        static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
@@ -1895,7 +1903,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                 KPerBlock);
        // B scale
        static constexpr auto mfma =
            MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
        static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
        static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
        static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
...
@@ -489,8 +489,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t KPack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
                .k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
...
@@ -487,9 +487,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        else if(TileMathThreadGroup::IsBelong())
        {
            // branch early for math wave
            constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
            constexpr bool is_single_rate_mfma =
                ((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
                 lcm_AK1_BK1 <= 4)
                    ? true
                    : false;
            constexpr index_t KPack = math::max(
                lcm_AK1_BK1,
                MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
                    selected_mfma.k_per_blk);

            auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
                TileMathThreadGroupSize,
...
@@ -446,8 +446,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
        constexpr bool is_single_rate_mfma =
            ((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
             lcm_AK1_BK1 <= 4)
                ? true
                : false;
        constexpr index_t k_pack = math::max(
            lcm_AK1_BK1,
            MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
                .k_per_blk);

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
...
@@ -68,52 +68,82 @@ struct transpose_vectors
            }
        }
        else if constexpr(sizeof(S) == 1)
        {
            static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)),
                          "wrong!");

            using S4 = array<S, 4>;
            using S2 = array<S, 2>;

            if constexpr(NX % 4 == 0 && NY % 4 == 0)
            {
                // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
                static_for<0, NY, 4>{}([&](auto iy) {
                    static_for<0, NX, 4>{}([&](auto ix) {
                        // 4 int8x4 data from vx_tuple
                        const int32_t x_s4_0 =
                            bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
                        const int32_t x_s4_1 =
                            bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
                        const int32_t x_s4_2 =
                            bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
                        const int32_t x_s4_3 =
                            bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);

                        // transpose
                        int32_t t_s4_0, t_s4_1;
                        int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;

                        constexpr int32_t m0 = 0x05010400;
                        constexpr int32_t m1 = 0x05040100;
                        constexpr int32_t m2 = 0x07060302;
                        constexpr int32_t m3 = 0x07030602;

                        // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
                        // 0x33774488
                        //                   -- -- -- --      -- -- -- --       -  -  -  -
                        // byte index         7  6  5  4       3  2  1  0      33 77 44 88
                        // index is reversed because of little endianness (least significant
                        // bits first)
                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
                        y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                        y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
                        y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                        y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

                        // 4 int8x4 data into vy_tuple
                        vy_tuple(iy).template get_as<S4>()(ix / I4)      = bit_cast<S4>(y_s4_0);
                        vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
                        vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
                        vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
                    });
                });
            }
            else if constexpr(NX % 2 == 0 && NY % 2 == 0)
            {
                // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
                static_for<0, NY, 2>{}([&](auto iy) {
                    static_for<0, NX, 2>{}([&](auto ix) {
                        // 2 int8x2 data from vx_tuple
                        const int16_t x_s2_0 =
                            bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
                        const int16_t x_s2_1 =
                            bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);

                        const int32_t x0_32 = static_cast<int32_t>(x_s2_0) & 0xFFFF;
                        const int32_t x1_32 = static_cast<int32_t>(x_s2_1) & 0xFFFF;

                        // interleave [a0 a1] and [b0 b1] into [a0 b0 a1 b1]; the low and
                        // high halfwords are then the two transposed byte pairs
                        constexpr int32_t m0 = 0x05010400;
                        const int32_t y_32  = __builtin_amdgcn_perm(x1_32, x0_32, m0);

                        // 2 int8x2 data into vy_tuple
                        vy_tuple(iy).template get_as<S2>()(ix / I2) =
                            bit_cast<S2>(static_cast<int16_t>(y_32 & 0xFFFF));
                        vy_tuple(iy + I1).template get_as<S2>()(ix / I2) =
                            bit_cast<S2>(static_cast<int16_t>((y_32 >> 16) & 0xFFFF));
                    });
                });
            }
        }
        else
        {
...
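For readers checking the selector constants above, here is a host-side reference model of v_perm_b32. It mirrors the semantics of the worked example in the comment (bytes 0-3 select from the second operand, 4-7 from the first, least-significant byte first); perm_b32 is an illustration for verification, not part of CK:

// Host-side reference for v_perm_b32 / __builtin_amdgcn_perm, usable to
// sanity-check permute selectors in a plain C++ compile.
#include <cstdint>

constexpr uint32_t perm_b32(uint32_t hi, uint32_t lo, uint32_t sel)
{
    // {hi, lo} form an 8-byte source; selector byte i picks source byte sel_i
    uint64_t src = (static_cast<uint64_t>(hi) << 32) | lo;
    uint32_t out = 0;
    for(int i = 0; i < 4; ++i)
    {
        uint32_t s = (sel >> (8 * i)) & 0xFF;
        out |= static_cast<uint32_t>((src >> (8 * s)) & 0xFF) << (8 * i);
    }
    return out;
}

// the worked example from the comment above
static_assert(perm_b32(0x11223344u, 0x55667788u, 0x05010400u) == 0x33774488u);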
@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
@@ -14,24 +14,54 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
    private:
    template <typename PipelineProblem_, typename GemmPolicy_>
    struct GemmTraits_
    {
        using Problem        = remove_cvref_t<PipelineProblem_>;
        using Policy         = remove_cvref_t<GemmPolicy_>;
        using ADataType      = remove_cvref_t<typename Problem::ADataType>;
        using BDataType      = remove_cvref_t<typename Problem::BDataType>;
        using CDataType      = remove_cvref_t<typename Problem::CDataType>;
        using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

        static constexpr index_t kBlockSize = Problem::kBlockSize;

        static constexpr index_t MPerBlock = BlockGemmShape::kM;
        static constexpr index_t NPerBlock = BlockGemmShape::kN;
        static constexpr index_t KPerBlock = BlockGemmShape::kK;

        static constexpr auto config   = Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WarpGemm                 = remove_cvref_t<decltype(config.template at<0>())>;
        static constexpr index_t MWarp = config.template at<1>();
        static constexpr index_t NWarp = config.template at<2>();

        static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
        static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

        static constexpr index_t KPack = WarpGemm::kKPerThread;
    };

    public:
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;
    using Traits  = GemmTraits_<Problem, Policy>;

    using WarpGemm       = typename Traits::WarpGemm;
    using BlockGemmShape = typename Traits::BlockGemmShape;

    using ADataType = remove_cvref_t<typename Traits::ADataType>;
    using BDataType = remove_cvref_t<typename Traits::BDataType>;
    using CDataType = remove_cvref_t<typename Traits::CDataType>;

    static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
    static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
    static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;

    static constexpr index_t MWarp = Traits::MWarp;
    static constexpr index_t NWarp = Traits::NWarp;

    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
@@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1
                sequence<1, 2>,
                sequence<0, 0>>{};
        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
        return a_block_dstr_encode;
    }
@@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
                sequence<1, 2>,
                sequence<0, 0>>{};
        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
        return b_block_dstr_encode;
    }
@@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1
                sequence<1, 2>,
                sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        return c_block_dstr_encode;
    }
@@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
                    .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

        using AWarpDstr = typename WarpGemm::AWarpDstr;
        using BWarpDstr = typename WarpGemm::BWarpDstr;
        using CWarpDstr = typename WarpGemm::CWarpDstr;

        using AWarpTensor = typename WarpGemm::AWarpTensor;
        using BWarpTensor = typename WarpGemm::BWarpTensor;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
@@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                // warp GEMM
                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
@@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
                sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
...
@@ -463,7 +463,9 @@ struct GemmKernel
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
     * @param smem_ptr_0 The starting pointer of the shared memory block.
     * @param kargs GEMM kernel arguments
     * @param splitk_batch_offset Utility structure used to calculate the k-batch offset.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
@@ -473,7 +475,7 @@ struct GemmKernel
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       CDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GemmKernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
@@ -491,15 +493,67 @@ struct GemmKernel
        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);
        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
                c_block_window, c_block_tile, smem_ptr_0);
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @note RunGemm2LDS runs the GEMM with two shared memory buffers using a
     *       ping-pong buffer mechanism.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
     * @param smem_ptr_0 The starting pointer of 1st shared memory block.
     * @param smem_ptr_1 The starting pointer of 2nd shared memory block.
     * @param kargs GEMM kernel arguments
     * @param splitk_batch_offset Utility structure used to calculate the k-batch offset.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
     * @tparam DstInMemOp Destination memory operation (default: set).
     */
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
                                           const BDataType* b_ptr,
                                           CDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const GemmKernelArgs& kargs,
                                           const SplitKBatchOffset& splitk_batch_offset,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);
        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
                c_block_window, c_block_tile, smem_ptr_0);
    }
    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
@@ -517,11 +571,27 @@ struct GemmKernel
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];
        __shared__ char smem_ptr_1[GetSmemSize()];

        if(kargs.k_batch == 1)
        {
            if constexpr(GemmPipeline::DoubleSmemBuffer == true)
            {
                RunGemm2LDS(a_ptr,
                            b_ptr,
                            c_ptr,
                            smem_ptr_0,
                            smem_ptr_1,
                            kargs,
                            splitk_batch_offset,
                            i_m,
                            i_n);
            }
            else
            {
                RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
            }
        }
        else
        {
@@ -530,8 +600,23 @@ struct GemmKernel
            if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                           is_any_of<CDataType, fp16_t, bf16_t>::value))
            {
                if constexpr(GemmPipeline::DoubleSmemBuffer == true)
                {
                    RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
                                                                   b_ptr,
                                                                   c_ptr,
                                                                   smem_ptr_0,
                                                                   smem_ptr_1,
                                                                   kargs,
                                                                   splitk_batch_offset,
                                                                   i_m,
                                                                   i_n);
                }
                else
                {
                    RunGemm<memory_operation_enum::atomic_add>(
                        a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
                }
            }
        }
    }
...
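The ping-pong mechanism RunGemm2LDS relies on is easiest to see in a host-side model. The sketch below is purely illustrative (ints stand in for LDS tiles, printf for the MFMA stage); on the GPU, the pipeline's block_sync_lds() calls separate the fill and consume roles of each buffer:

// Host-side model of the ping-pong double buffer: while the math stage
// consumes one buffer, the overlapped load stage fills the other.
#include <array>
#include <cstdio>

int main()
{
    std::array<int, 2> lds{}; // two "LDS" buffers
    int next_tile = 0;        // next K-tile fetched from "global memory"
    const int num_loop = 6;

    lds[0] = next_tile++; // prologue: prefill buffer 0
    for(int i = 0; i < num_loop; ++i)
    {
        const int compute = i % 2; // buffer consumed this iteration
        const int fill    = 1 - compute;
        if(next_tile < num_loop)
            lds[fill] = next_tile++;                               // "global -> LDS" stage
        std::printf("iter %d: mfma on tile %d\n", i, lds[compute]); // "math" stage
    }
    return 0;
}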
@@ -41,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase
        store_tile(lds_tile_window, block_tile_tmp);
    }

    template <typename DstBlockTile, typename SrcTileWindow>
    CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
                                      const SrcTileWindow& lds_tile_window) const
    {
        load_tile(dst_block_tile, lds_tile_window);
    }

    CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
    {
        // A tile in LDS
        ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // TODO: LDS alignment should come from Policy!
        constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
            sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);

        // B tile in LDS
        BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
...
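The alignment change above assumes integer_least_multiple(x, m) rounds x up to the next multiple of m; it replaces the old divide-then-multiply by 16 one-to-one, so the B-tile base keeps its 16-byte alignment. A quick sketch of those presumed semantics:

// sketch of the presumed round-up semantics; not CK's implementation
#include <cstddef>

constexpr std::size_t integer_least_multiple_sketch(std::size_t x, std::size_t m)
{
    return ((x + m - 1) / m) * m;
}

static_assert(integer_least_multiple_sketch(100, 16) == 112, "rounds up");
static_assert(integer_least_multiple_sketch(112, 16) == 112, "already aligned");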
@@ -76,6 +76,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
};
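// Worked example of GetBlockLoopTailNum above, with PrefetchStages == 2:
//   num_loop = 5 -> 5 % 2 == 1 -> TailNumber::Three (three tiles drain after the hot loop)
//   num_loop = 6 -> 6 % 2 == 0 -> TailNumber::Two   (two tiles drain)
// The ping-pong prologue issues two global prefetches before entering the hot
// loop, which is why the tail length depends only on num_loop's parity.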
/**
 * @brief Compute optimized pipeline version 4
 *
 * This version introduces a dual LDS window mechanism using a ping-pong buffer
 * approach for more efficient data handling from global memory. Unlike compute
 * version 3, one LDS buffer is filled from global memory while the warps run
 * MFMA matrix multiplication out of the other. This overlap keeps the MFMA
 * units continuously busy, hiding memory load latency and improving overall
 * performance.
 *
 * @note This version shows improved performance over compute version 3 with the
 * same block tile. It is particularly effective for large matrices where M, N,
 * and K exceed 8K, even when compute version 3 uses a block size twice that of
 * compute version 4.
 */
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
}
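// Note on the sched_group_barrier masks used above (per LLVM's AMDGPU
// scheduling intrinsics): 0x008 = MFMA/WMMA, 0x020 = VMEM read,
// 0x100 = DS read, 0x200 = DS write. Each of the num_issue groups
// interleaves one MFMA between that issue slot's share of DS reads,
// DS writes and one buffer load, spreading memory traffic evenly
// across the hot loop's MFMA stream.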
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// global prefetch 0
// global read 0
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
auto a_copy_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_copy_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
// global read 1
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0;
ALdsTile a_block_tile1;
BLdsTile b_block_tile0;
BLdsTile b_block_tile1;
auto a_lds_ld_window0 =
make_tile_window_linear(a_lds_block0,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window_linear(a_lds_block1,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window_linear(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window_linear(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
if(HasHotLoop)
{
// minus 2 because we have ping-pong double buffer.
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
do
{
// ping
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window0, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
iCounter -= 2;
} while(iCounter > 1);
}
// tail 3
if(TailNum == TailNumber::Three)
{
// 3
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
else
{
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
// 1
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_barrier(0);
}
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
}
};
} // namespace ck_tile