#pragma once // Misc #include "cute/tensor.hpp" #include "cutlass/arch/arch.h" #include "cutlass/arch/mma.h" #include "cutlass/cutlass.h" #include "cutlass/detail/blockwise_scale_layout.hpp" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/layout/layout.h" #include "cutlass/numeric_conversion.h" #include "cutlass/numeric_size.h" // Collective Builder #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" #include "cutlass/epilogue/thread/activation.h" #include "cutlass/gemm/collective/collective_builder.hpp" // Integration #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" namespace expert_specialization { using namespace cute; struct PerfConfigLowMH20 { // Swap A/B using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_128, _32, _128>; using ClusterShape = Shape<_2, _1, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; struct PerfConfigLowMHx00 { // Swap A/B using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_256, _32, _128>; using ClusterShape = Shape<_2, _1, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; struct PerfConfigMiddleMH20 { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; struct PerfConfigMiddleMHx00 { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_256, _64, _128>; using ClusterShape = Shape<_2, _1, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; struct PerfConfigHighMH20 { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; struct PerfConfigHighMHx00 { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; template struct ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits { using ElementA = cutlass::float_e4m3_t; using ElementB = cutlass::float_e4m3_t; using ElementC = void; using ElementD = OutType; using ElementAccumulator = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = LayoutD; using LayoutSFA = typename PerfConfig::LayoutSFA; using LayoutSFB = typename PerfConfig::LayoutSFB; using ProblemShape = cutlass::gemm::GroupProblemShape>; static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; using CustomEVTIdentity = // acc cutlass::epilogue::fusion::Sm90EVT< cutlass::epilogue::fusion:: Sm90Compute, cutlass::epilogue::fusion::Sm90AccFetch>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, typename PerfConfig::MmaTileShape, typename PerfConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, // Use void to avoid load Matrix C LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD, typename PerfConfig::EpilogueSchedule, CustomEVTIdentity>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, cute::tuple, AlignmentA, ElementB, cute::tuple, AlignmentB, ElementAccumulator, typename PerfConfig::MmaTileShape, typename PerfConfig::ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, typename PerfConfig::KernelSchedule>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; using StrideA = typename Gemm::GemmKernel::InternalStrideA; using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideC = typename Gemm::GemmKernel::InternalStrideC; using StrideD = typename Gemm::GemmKernel::InternalStrideD; }; } // namespace expert_specialization