Unverified Commit 7b826807 authored by jefyang1's avatar jefyang1 Committed by GitHub
Browse files

Fix KPack and enable existing instances on gfx950 (#1871)

parent 3c7fef7f
......@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -361,8 +361,16 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
constexpr bool is_single_rate_mfma =
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
math::lcm(A0K1, B0K1) <= 4)
? true
: false;
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
A0B0B1DataType,
is_single_rate_mfma>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
......@@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr index_t KPack = math::max(
math::lcm(A0K1, B0K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
constexpr bool is_single_rate_mfma =
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4)
? true
: false;
constexpr index_t KPack =
math::max(lcm_A0K1_B0K1,
MfmaSelector<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
A0B0B1DataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -343,7 +343,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
math::lcm(AK1, BK1) <= 4)
? true
: false;
constexpr auto mfma =
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
......@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
......@@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateMPadded(index_t M)
......
......@@ -869,9 +869,16 @@ struct GridwiseGemm_xdl_cshuffle_v2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
......
......@@ -147,9 +147,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
......@@ -155,9 +155,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -1424,7 +1431,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// b scale
// static_assert(KPerBlock <= ScaleBlockK);
static constexpr auto mfma = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>{};
static constexpr auto mfma =
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
......@@ -1895,7 +1903,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock);
// B scale
static constexpr auto mfma = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>{};
static constexpr auto mfma =
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
......
......@@ -489,8 +489,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
......
......@@ -487,9 +487,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
else if(TileMathThreadGroup::IsBelong())
{
// branch early for math wave
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,
......
......@@ -446,8 +446,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
constexpr index_t k_pack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment