Unverified Commit 6e92da8f authored by Qi Yuhang's avatar Qi Yuhang Committed by GitHub
Browse files

[Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127)

parent e1020dc5
...@@ -40,27 +40,21 @@ using namespace cute; ...@@ -40,27 +40,21 @@ using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config // Kernel Perf config
template <typename T> template <typename T>
struct KernelTraits; struct KernelTraits {
template <>
struct KernelTraits<float> {
using MmaTileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};
template <>
struct KernelTraits<cutlass::half_t> {
using MmaTileShape = Shape<_256, _256, _256>; using MmaTileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<_4, _4, _1>; using ClusterShape = Shape<int, int, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>; using EpilogueTile = Shape<_128, _64>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
}; };
template <> template <>
struct KernelTraits<cutlass::bfloat16_t> { struct KernelTraits<float> {
using MmaTileShape = Shape<_256, _256, _256>; using MmaTileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_4, _4, _1>; using ClusterShape = Shape<int, int, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>; using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
}; };
template <typename T> template <typename T>
...@@ -90,23 +84,26 @@ struct Fp4GemmSm100 { ...@@ -90,23 +84,26 @@ struct Fp4GemmSm100 {
// Kernel Perf config // Kernel Perf config
using MmaTileShape = typename KernelTraits<T>::MmaTileShape; using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
using ClusterShape = typename KernelTraits<T>::ClusterShape; using ClusterShape = typename KernelTraits<T>::ClusterShape;
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK; using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, ArchTag,
OperatorClass, cutlass::arch::OpClassTensorOp,
PerSmTileShape_MNK, MmaTileShape,
ClusterShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, EpilogueTile,
ElementAccumulator, ElementAccumulator,
ElementAccumulator, ElementAccumulator,
ElementC, void,
LayoutCTag, LayoutCTag,
AlignmentC, AlignmentC,
ElementD, ElementD,
LayoutDTag, LayoutDTag,
AlignmentD, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, ArchTag,
...@@ -122,7 +119,7 @@ struct Fp4GemmSm100 { ...@@ -122,7 +119,7 @@ struct Fp4GemmSm100 {
ClusterShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>, sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; MainloopSchedule>::CollectiveOp;
using GemmKernel = using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
...@@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options( ...@@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options(
stride_D}}; stride_D}};
auto& fusion_args = arguments.epilogue.thread; auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr()); fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
if constexpr (std::is_same_v<T, float>) {
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
} else {
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
}
return arguments; return arguments;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment