"...composable_kernel_rocm.git" did not exist on "827301d95af93d581ddac8d2734ec759ea215c6c"
Commit e95cb82b authored by jefyang1's avatar jefyang1
Browse files

Fix gemm gemm on gfx950

parent c38163cd
......@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()
......@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
......@@ -608,14 +608,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -773,14 +773,10 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
......
......@@ -628,14 +628,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
......
......@@ -880,13 +880,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
typename additional_type = base_type,
bool is_single_rate_mfma = false>
struct MfmaSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
static constexpr auto GetMfma();
template <>
......@@ -950,7 +952,7 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<half_t, 32, 32>()
constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16f16;
......@@ -958,9 +960,14 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x8f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 16>()
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16;
......@@ -969,6 +976,12 @@ struct MfmaSelector
#endif
}
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 64>()
{
......@@ -988,7 +1001,7 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32>()
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16bf16;
......@@ -1000,7 +1013,17 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16>()
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
return MfmaInstr::mfma_f32_32x32x4bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16;
......@@ -1011,6 +1034,16 @@ struct MfmaSelector
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
#endif
}
#if defined(__gfx950__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
......@@ -1095,7 +1128,7 @@ struct MfmaSelector
}
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
__host__ __device__ constexpr MfmaSelector()
{
......@@ -1397,7 +1430,9 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
// Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type,
((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4) ? true : false>{};
static constexpr auto mfma_instr = mfma.selected_mfma;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment