Unverified Commit 2d7779f8 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[Perf] SM100 FP8 GEMM Optimizations after cutlass_profiler (#20071)


Signed-off-by: default avatarilmarkov <imarkov@redhat.com>
Co-authored-by: default avatarilmarkov <imarkov@redhat.com>
parent a57d57fa
...@@ -29,12 +29,12 @@ struct sm100_fp8_config_default { ...@@ -29,12 +29,12 @@ struct sm100_fp8_config_default {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 { struct sm100_fp8_config_M256 {
// M in (128, 256] // M in (64, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>; using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
...@@ -42,13 +42,13 @@ struct sm100_fp8_config_M256 { ...@@ -42,13 +42,13 @@ struct sm100_fp8_config_M256 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M128 { struct sm100_fp8_config_M64 {
// M in (64, 128] // M in (16, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _256>; using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_2, _4, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
...@@ -56,13 +56,13 @@ struct sm100_fp8_config_M128 { ...@@ -56,13 +56,13 @@ struct sm100_fp8_config_M128 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 { struct sm100_fp8_config_M16 {
// M in [1, 64] // M in [1, 16]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _256>; using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>; using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
...@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType, typename sm100_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm; Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM16 =
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 = using Cutlass3xGemmM64 =
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 = using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 64) { if (mp2 <= 16) {
// m in [1, 64] // m in [1, 16]
return cutlass_gemm_caller<Cutlass3xGemmM64>( return cutlass_gemm_caller<Cutlass3xGemmM16>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) { } else if (mp2 <= 64) {
// m in (64, 128] // m in (16, 64]
return cutlass_gemm_caller<Cutlass3xGemmM128>( return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) { } else if (mp2 <= 256) {
// m in (128, 256] // m in (64, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>( return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment