Unverified Commit 6e92da8f authored by Qi Yuhang's avatar Qi Yuhang Committed by GitHub
Browse files

[Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127)

parent e1020dc5
......@@ -40,27 +40,21 @@ using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;
template <>
struct KernelTraits<float> {
using MmaTileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};
template <>
struct KernelTraits<cutlass::half_t> {
struct KernelTraits {
using MmaTileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<_4, _4, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = Shape<_128, _64>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
};
template <>
struct KernelTraits<cutlass::bfloat16_t> {
using MmaTileShape = Shape<_256, _256, _256>;
using ClusterShape = Shape<_4, _4, _1>;
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
struct KernelTraits<float> {
using MmaTileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<int, int, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
};
template <typename T>
......@@ -90,23 +84,26 @@ struct Fp4GemmSm100 {
// Kernel Perf config
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
using ClusterShape = typename KernelTraits<T>::ClusterShape;
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
PerSmTileShape_MNK,
cutlass::arch::OpClassTensorOp,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
EpilogueTile,
ElementAccumulator,
ElementAccumulator,
ElementC,
void,
LayoutCTag,
AlignmentC,
ElementD,
LayoutDTag,
AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
......@@ -122,7 +119,7 @@ struct Fp4GemmSm100 {
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
MainloopSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
......@@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
if constexpr (std::is_same_v<T, float>) {
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
} else {
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
}
return arguments;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment