Unverified Commit e13945f9 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[Perf] Further tunings for SM100 FP8 CUTLASS kernel (#19566)

parent 08500011
...@@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller; ...@@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default { struct sm100_fp8_config_default {
// M in (128, inf) // M in (256, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _64>; using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
// M in (128, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>; using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
...@@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 { ...@@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _64>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_2, _2, _1>; using ClusterShape = Shape<_2, _4, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
...@@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 = using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
...@@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
// m in (64, 128] // m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>( return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// m in (128, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else { } else {
// m in (128, inf) // m in (256, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>( return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...); out, a, b, std::forward<EpilogueArgs>(args)...);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment