Unverified Commit a8bffaa1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Kernel] Add MXFP4 W4A4 CUTLASS MoE kernel for SM100 (#37463)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 5cdddddd
......@@ -141,6 +141,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_flashinfer_moe.py
......
......@@ -952,7 +952,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
......
......@@ -134,4 +134,27 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
// Quantizes per-expert activations to MXFP4, writing the packed values to
// `output` and the block scale factors to `output_scale`.
// `input_offset_by_experts` and `output_scale_offset_by_experts` carry the
// per-expert offsets into the input rows and the scale-factor buffer.
// NOTE(review): dtype/shape requirements are enforced in the implementation
// file, not here — confirm against mxfp4_experts_quant.cu before relying on
// a particular layout.
void mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
// Same argument list as mxfp4_experts_quant.
// NOTE(review): presumably the fused variant that applies SiLU(gate) * up to
// a gate||up input before MXFP4 quantization — confirm against the
// implementation file.
void silu_and_mul_mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
// MXFP4 x MXFP4 block-scaled grouped GEMM for MoE (one GEMM per expert).
//   output         - result tensor (float16 or bfloat16).
//   a, b           - packed FP4 operands stored as uint8.
//   a_blockscale,
//   b_blockscales  - E8M0 block scale factors stored as uint8.
//   problem_sizes  - int32 tensor of shape (num_experts, 3).
//   expert_offsets - per-expert row offsets into `a` / `output`.
//   sf_offsets     - per-expert offsets into `a_blockscale`.
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
#endif
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* MXFP4 x MXFP4 block-scaled grouped GEMM kernel for MoE on SM100.
* Uses Cutlass mx_float4_t operands, E8M0 block scales, and 32-element groups.
*/
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <cutlass/arch/arch.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
#include <cassert>
using namespace cute;
// Per-expert argument setup for the MXFP4 grouped GEMM (scale group size 32).
//
// Each thread handles the expert whose index equals its threadIdx.x. For that
// expert it writes:
//   - the start pointers of its A rows, B matrix, output rows and A/B scale
//     factors into the *_offsets pointer arrays,
//   - the (expert-independent) A/B/C stride values into the stride arrays,
//   - the scale-factor layouts for its (m, n, k) problem.
//
// expert_offsets / sf_offsets give the expert's first row in A/output and in
// the A scale-factor buffer; problem_sizes_as_shapes holds (m, n, k) triples.
// The device-side asserts check that n and k match the N and K arguments, that
// k is even, and that both scale pointers are 128-byte aligned.
template <typename ElementAB, typename ElementC, typename ElementSF,
          typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
__global__ void __mxfp4_get_group_gemm_starts(
    ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
    ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
    LayoutSFA* layout_sfa_base_as_int, LayoutSFB* layout_sfb_base_as_int,
    ElementAB* a_base_as_int, ElementAB* b_base_as_int,
    ElementC* out_base_as_int, ElementSF* a_scales_base_as_int,
    ElementSF* b_scales_base_as_int, const int32_t* expert_offsets,
    const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
    int64_t* a_strides, int64_t* b_strides, int64_t* c_strides,
    const int64_t a_stride_val, const int64_t b_stride_val,
    const int64_t c_stride_val, const int K, const int N) {
  // The expert index is taken from threadIdx.x alone.
  const int64_t expert_id = threadIdx.x;
  if (expert_id >= gridDim.x * blockDim.x) {
    return;
  }

  // Number of K elements covered by one scale factor.
  constexpr int64_t kScaleGroupSize = 32;

  // First row of this expert in A/output, and in the A scale-factor buffer.
  const int64_t row_start = static_cast<int64_t>(expert_offsets[expert_id]);
  const int64_t sf_row_start = static_cast<int64_t>(sf_offsets[expert_id]);

  // Problem sizes are stored as consecutive (m, n, k) triples.
  const int32_t* dims = problem_sizes_as_shapes + expert_id * 3;
  const int64_t m = static_cast<int64_t>(dims[0]);
  const int64_t n = static_cast<int64_t>(dims[1]);
  const int64_t k = static_cast<int64_t>(dims[2]);
  assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
         "unexpected problem sizes");

  // Two FP4 values share one byte, so a row of K values spans K / 2 bytes.
  const int64_t half_k = static_cast<int64_t>(k / 2);
  // Scale factors per row.
  const int64_t group_k = static_cast<int64_t>(k / kScaleGroupSize);

  // A viewed as bytes: [M, K // 2]
  a_offsets[expert_id] = a_base_as_int + row_start * half_k;
  // B viewed as bytes: [E, N, K // 2]
  b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
  // C: [M, N]
  out_offsets[expert_id] = out_base_as_int + row_start * n;

  // A scales: [sum(sf_sizes), K // group_size]
  ElementSF* const a_sf_start = a_scales_base_as_int + sf_row_start * group_k;
  a_scales_offsets[expert_id] = a_sf_start;
  assert((reinterpret_cast<uintptr_t>(a_sf_start) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // B scales: [E, N, K // group_size]
  ElementSF* const b_sf_start = b_scales_base_as_int + expert_id * n * group_k;
  b_scales_offsets[expert_id] = b_sf_start;
  assert((reinterpret_cast<uintptr_t>(b_sf_start) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // Every expert uses the same stride values.
  a_strides[expert_id] = a_stride_val;
  b_strides[expert_id] = b_stride_val;
  c_strides[expert_id] = c_stride_val;

  // Scale-factor layouts for this expert's (m, n, k, 1) problem.
  const int m_i = static_cast<int>(m);
  const int n_i = static_cast<int>(n);
  const int k_i = static_cast<int>(k);
  layout_sfa_base_as_int[expert_id] =
      ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m_i, n_i, k_i, 1));
  layout_sfb_base_as_int[expert_id] =
      ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m_i, n_i, k_i, 1));
}
// Expands to one `else if` branch of an output-dtype dispatch chain (it must
// follow an `if`/`else if`). When out_tensors has scalar type TENSOR_C_TYPE it
// launches __mxfp4_get_group_gemm_starts on `stream` as a single block of
// `num_experts` threads, casting each tensor's data pointer to the element
// type the kernel expects.
//
// The expansion refers to these names, which must be in scope at the use
// site: out_tensors, num_experts, stream, a_starts, b_starts, out_starts,
// a_scales_starts, b_scales_starts, layout_sfa, layout_sfb, a_tensors,
// b_tensors, a_scales, b_scales, expert_offsets, sf_offsets, problem_sizes,
// a_strides, b_strides, c_strides, a_stride_val, b_stride_val, c_stride_val,
// K and N.
#define __CALL_MXFP4_GET_STARTS_KERNEL(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__mxfp4_get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
static_cast<int64_t*>(a_strides.data_ptr()), \
static_cast<int64_t*>(b_strides.data_ptr()), \
static_cast<int64_t*>(c_strides.data_ptr()), a_stride_val, \
b_stride_val, c_stride_val, K, N); \
}
// Host-side launcher for __mxfp4_get_group_gemm_starts.
//
// Validates the shapes it depends on, then dispatches on the output dtype
// (bfloat16 or float16) to launch the setup kernel on the current CUDA stream
// of a_tensors' device. The kernel fills the *_starts pointer arrays, the
// stride arrays and the scale-factor layout buffers, one entry per expert.
//
// Parameters:
//   a_starts .. b_scales_starts - output arrays of per-expert device pointers.
//   layout_sfa, layout_sfb      - output buffers for per-expert SF layouts.
//   a_strides .. c_strides      - output arrays of per-expert strides.
//   a_stride_val .. c_stride_val- stride value written for every expert.
//   a_tensors .. b_scales       - flat operand / output / scale tensors.
//   expert_offsets, sf_offsets  - int32 per-expert offsets; expert_offsets'
//                                 size(0) defines the number of experts.
//   problem_sizes               - int32 (m, n, k) triples, one per expert.
//   M                           - unused here; kept for interface stability.
//   N, K                        - expected n / k of every expert's problem.
//
// Raises (STD_TORCH_CHECK) when the expert count exceeds what one thread block
// can cover, when shapes are inconsistent, or when the output dtype is
// neither float16 nor bfloat16.
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void mxfp4_run_get_group_gemm_starts(
    const torch::stable::Tensor& a_starts,
    const torch::stable::Tensor& b_starts,
    const torch::stable::Tensor& out_starts,
    const torch::stable::Tensor& a_scales_starts,
    const torch::stable::Tensor& b_scales_starts,
    const torch::stable::Tensor& layout_sfa,
    const torch::stable::Tensor& layout_sfb,
    const torch::stable::Tensor& a_strides,
    const torch::stable::Tensor& b_strides,
    const torch::stable::Tensor& c_strides, int64_t a_stride_val,
    int64_t b_stride_val, int64_t c_stride_val,
    torch::stable::Tensor const& a_tensors,
    torch::stable::Tensor const& b_tensors,
    torch::stable::Tensor const& out_tensors,
    torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales,
    torch::stable::Tensor const& expert_offsets,
    torch::stable::Tensor const& sf_offsets,
    torch::stable::Tensor const& problem_sizes, int M, int N, int K) {
  (void)M;  // not needed by the setup kernel
  int num_experts = (int)expert_offsets.size(0);
  auto stream = get_current_cuda_stream(a_tensors.get_device_index());

  // The setup kernel runs as ONE block with one thread per expert, and a CUDA
  // block holds at most 1024 threads. A larger launch would fail without being
  // reported and leave every pointer array uninitialised for the GEMM.
  STD_TORCH_CHECK(num_experts <= 1024,
                  "MXFP4 grouped GEMM setup supports at most 1024 experts per "
                  "call, got ",
                  num_experts);
  STD_TORCH_CHECK(out_tensors.size(1) == N,
                  "Output tensor shape doesn't match expected shape");
  // K / 2 is the packed (byte) width of one row; both operands must agree.
  STD_TORCH_CHECK(K / 2 == b_tensors.size(2) &&
                      a_tensors.size(1) == b_tensors.size(2),
                  "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
                  " dimension must match");

  // Output-dtype dispatch: each macro use contributes one `else if` branch.
  if (false) {
  }
  // MXFP4 uses E8M0 (float_ue8m0_t) scale factors
  __CALL_MXFP4_GET_STARTS_KERNEL(cutlass::float_e2m1_t, cutlass::float_ue8m0_t,
                                 torch::headeronly::ScalarType::BFloat16,
                                 cutlass::bfloat16_t, LayoutSFA, LayoutSFB,
                                 ScaleConfig)
  __CALL_MXFP4_GET_STARTS_KERNEL(cutlass::float_e2m1_t, cutlass::float_ue8m0_t,
                                 torch::headeronly::ScalarType::Half, half,
                                 LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
// Builds and runs the SM100 CUTLASS grouped GEMM (one problem per expert) with
// MXFP4 block-scaled operands.
//
// Template parameter:
//   OutType - CUTLASS element type used for both ElementC and ElementD.
// Parameters:
//   output         - destination tensor; output.stride(0) is the row stride.
//   a, b           - packed FP4 operands. Their strides are multiplied by 2
//                    before being handed to the setup kernel.
//   a_blockscale,
//   b_blockscales  - block scale factors for A and B.
//   problem_sizes  - reinterpreted as an array of per-expert problem shapes.
//   expert_offsets - per-expert offsets; size(0) is used as the group count.
//   sf_offsets     - per-expert offsets into a_blockscale.
//   M, N, K        - forwarded to the setup helper and echoed in the
//                    initialization error message.
// Raises (STD_TORCH_CHECK) when CUTLASS reports that the problem cannot be
// implemented, initialized or run.
template <typename OutType>
void run_mxfp4_blockwise_scaled_group_mm_sm100(
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
// Each group is described by an (m, n, k) triple of int32.
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
// Raw storage types used when casting the pointer arrays below.
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue8m0_t;
// Operand types handed to the collective builder.
using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignments are expressed in elements; C/D use 128 bits worth of elements.
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp;
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using ClusterShape = Shape<_1, _1, _1>;
// Single-SM MMA configuration: tile shape plus matching mainloop/epilogue
// schedules.
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
};
// The trailing `*` on the layouts selects the pointer-array (grouped) form.
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
ClusterShape, Shape<_128, _64>, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
LayoutC*, AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
// Mainloop stage count is chosen automatically after reserving the shared
// memory the epilogue needs.
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
ElementB, LayoutB*, AlignmentB, ElementAccumulator,
typename MMA1SMConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
// Per-expert argument arrays, allocated on a's device. They are int64 tensors
// used as raw storage: the pointer arrays hold device pointers and the stride
// arrays hold one stride value per expert (see the casts further down).
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
// Storage for one LayoutSFA / LayoutSFB object per expert (5 x int64 each).
// NOTE(review): assumes sizeof(LayoutSFA) and sizeof(LayoutSFB) fit in
// 5 * sizeof(int64_t) bytes — confirm against the CUTLASS version in use.
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
// Fill all of the arrays above on the device. The A/B strides are doubled
// here; NOTE(review): presumably because two FP4 values share one byte, so a
// byte stride corresponds to twice as many FP4 elements — confirm.
mxfp4_run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, layout_sfa,
layout_sfb, a_strides1, b_strides1, c_strides1, a.stride(0) * 2,
b.stride(1) * 2, output.stride(0), a, b, output, a_blockscale,
b_blockscales, expert_offsets, sf_offsets, problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;
// problem_sizes' data is passed to CUTLASS as the per-group shape array.
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device_index();
// SM count is queried once per device id and cached for later calls.
// NOTE(review): the function-local static map is accessed without any
// synchronization — confirm callers never run this concurrently.
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
// Order: A pointers, A strides, B pointers, B strides, then the scale-factor
// pointers and layouts for A and for B.
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides1.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides1.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
// No C source is supplied (nullptr); the C and D strides share one array.
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides1.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides1.data_ptr())};
auto& fusion_args = epilogue_args.thread;
// Scalar epilogue (CUTLASS grouped GEMM): D = 1 * accum + 0 * C
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha = 1.0f;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.dAlpha = {_0{}, _0{}, 0};
fusion_args.beta = 0.0f;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dBeta = {_0{}, _0{}, 0};
// Gemm Arguments
// The problem-shape triple is {group count, device shape array, nullptr}.
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
// Scratch workspace of the size CUTLASS requests, allocated on a's device.
size_t workspace_size = Gemm::get_workspace_size(args);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement MXFP4 GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize MXFP4 GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to run MXFP4 GEMM");
}
// Architecture dispatcher for the MXFP4 grouped GEMM.
//
// When the SM100 FP4 kernels are compiled in (ENABLE_NVFP4_SM100) and the
// device reports an SM version in [100, 120), forwards every argument to the
// SM100 implementation. In all other cases raises a "not implemented" error
// that reports the detected SM version.
template <typename OutType>
void run_mxfp4_blockwise_scaled_group_mm(
    torch::stable::Tensor& output, const torch::stable::Tensor& a,
    const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
    const torch::stable::Tensor& b_blockscales,
    const torch::stable::Tensor& problem_sizes,
    const torch::stable::Tensor& expert_offsets,
    const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
  const int32_t sm_version = get_sm_version_num();

#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
  const bool in_sm100_range = (sm_version >= 100) && (sm_version < 120);
  if (in_sm100_range) {
    run_mxfp4_blockwise_scaled_group_mm_sm100<OutType>(
        output, a, b, a_blockscale, b_blockscales, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
    return;
  }
#endif

  // Reached when no kernel was compiled for this device.
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_mxfp4_group_mm kernel for CUDA device capability: ",
      sm_version, ". Required capability: 100");
}
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
// Packed FP4 operands are exchanged with torch as uint8 tensors.
constexpr auto MXFP4_FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
// E8M0 scale factors stored as uint8
constexpr auto MXFP4_SF_DTYPE = torch::headeronly::ScalarType::Byte;
#endif

// Tensor validation helpers. `x` is the tensor, `st` the required scalar type
// and `m` the tensor's name for the error message.
#define CHECK_TYPE(x, st, m)                 \
  STD_TORCH_CHECK((x).scalar_type() == (st), \
                  ": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
  STD_TORCH_CHECK((x).is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
  STD_TORCH_CHECK((x).is_contiguous(), m, ": must be contiguous.")
// Runs all three checks. Wrapped in do/while(0) so the macro behaves as a
// single statement (e.g. under an unbraced `if`); callers supply the `;`.
#define CHECK_INPUT(x, st, m) \
  do {                        \
    CHECK_TH_CUDA(x, m);      \
    CHECK_CONTIGUOUS(x, m);   \
    CHECK_TYPE(x, st, m);     \
  } while (0)
// Entry point of the MXFP4 x MXFP4 block-scaled grouped GEMM for MoE.
//
// Parameters:
//   output         - result tensor, float16 or bfloat16.
//   a              - packed FP4 activations, uint8, shape [m, k // 2].
//   b              - packed FP4 weights, uint8, shape [num_experts, n, k // 2].
//   a_blockscale   - E8M0 scales for `a`, uint8, rank 2.
//   b_blockscales  - E8M0 scales for `b`, uint8, rank 3.
//   problem_sizes  - int32, shape (num_experts, 3) holding (m, n, k) rows.
//   expert_offsets - int32 per-expert row offsets; same length as
//                    problem_sizes.
//   sf_offsets     - int32 per-expert offsets into a_blockscale.
//
// Raises (STD_TORCH_CHECK) on any dtype/shape violation, and a
// "not implemented" error when built without ENABLE_NVFP4_SM100.
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
                            const torch::stable::Tensor& a,
                            const torch::stable::Tensor& b,
                            const torch::stable::Tensor& a_blockscale,
                            const torch::stable::Tensor& b_blockscales,
                            const torch::stable::Tensor& problem_sizes,
                            const torch::stable::Tensor& expert_offsets,
                            const torch::stable::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
  // Input validation
  CHECK_INPUT(a, MXFP4_FLOAT4_E2M1X2, "a");
  CHECK_INPUT(b, MXFP4_FLOAT4_E2M1X2, "b");
  // MXFP4 uses E8M0 scale factors (stored as uint8)
  CHECK_INPUT(a_blockscale, MXFP4_SF_DTYPE, "a_blockscale");
  CHECK_INPUT(b_blockscales, MXFP4_SF_DTYPE, "b_blockscales");

  // Ranks must be known before any size()/stride() lookups below.
  STD_TORCH_CHECK(a.dim() == 2,
                  "expected a to be of shape [m, k // 2], observed rank: ",
                  a.dim());
  STD_TORCH_CHECK(
      b.dim() == 3,
      "expected b to be of shape [num_experts, n, k // 2], observed rank: ",
      b.dim());
  // K is derived from b below, so a must be checked against it explicitly.
  STD_TORCH_CHECK(a.size(1) == b.size(2),
                  "a(dim = 1) and b(dim = 2) packed K dimensions must match, "
                  "observed: ",
                  a.size(1), " vs ", b.size(2));
  STD_TORCH_CHECK(
      a_blockscale.dim() == 2,
      "expected a_blockscale to be of shape [sum(sf_sizes),"
      " k // group_size], observed rank: ",
      a_blockscale.dim());
  STD_TORCH_CHECK(b_blockscales.dim() == 3,
                  "expected b_blockscale to be of shape: "
                  " [num_experts, n, k // group_size], observed rank: ",
                  b_blockscales.dim());
  STD_TORCH_CHECK(problem_sizes.dim() == 2,
                  "problem_sizes must be a 2D tensor");
  STD_TORCH_CHECK(problem_sizes.size(1) == 3,
                  "problem_sizes must have the shape (num_experts, 3)");
  STD_TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0),
      "Number of experts in problem_sizes must match expert_offsets");
  STD_TORCH_CHECK(
      problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
      "problem_sizes must be int32.");
  // Both offset tensors are read on the device through int32_t pointers.
  STD_TORCH_CHECK(
      expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
      "expert_offsets must be int32.");
  STD_TORCH_CHECK(
      sf_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
      "sf_offsets must be int32.");
  // Expert i addresses b[i], so there cannot be more problems than weights.
  STD_TORCH_CHECK(problem_sizes.size(0) <= b.size(0),
                  "Number of experts in problem_sizes (",
                  problem_sizes.size(0),
                  ") exceeds the number of experts in b (", b.size(0), ")");
  STD_TORCH_CHECK(
      output.scalar_type() == torch::headeronly::ScalarType::BFloat16 ||
          output.scalar_type() == torch::headeronly::ScalarType::Half,
      "output must be float16 or bfloat16");

  int M = static_cast<int>(a.size(0));
  int N = static_cast<int>(b.size(1));
  // b stores two FP4 values per byte, so K is twice its last dimension.
  int K = static_cast<int>(2 * b.size(2));

  if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    run_mxfp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  } else {
    run_mxfp4_blockwise_scaled_group_mm<cutlass::half_t>(
        output, a, b, a_blockscale, b_blockscales, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  }
#else
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_mxfp4_group_mm kernel; build vLLM with "
      "SM100 block-scaled FP4 MoE (ENABLE_NVFP4_SM100) and CUDA 12.8+.");
#endif
}
// Registers cutlass_mxfp4_group_mm as the CUDA implementation of the
// "cutlass_mxfp4_group_mm" op in the `_C` library, boxed through TORCH_BOX.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_mxfp4_group_mm", TORCH_BOX(&cutlass_mxfp4_group_mm));
}
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* MXFP4 activation quantization kernel for MoE experts.
* Quantizes BF16/FP16 activations to MXFP4: E2M1 values with E8M0 block scales
* over 32-element groups.
*
* Uses PACK16 E2M1 conversion helpers (nvfp4_utils.cuh) configured for:
* - Block size 32 (2 threads per SF in PACK16 mode)
* - E8M0 (power-of-two) scale factors
* - SF layout: [numMTiles, numKTiles, 32, 4, 4] where numKTiles=ceil(K/128)
*/
// MXFP4 requires PACK16 mode (16 elements per thread) so that
// 2 threads cover 32-element blocks. This requires CUDA >= 12.9.
// Must be defined before any header that (transitively) includes
// nvfp4_utils.cuh.
#define NVFP4_ENABLE_ELTS16 1
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
static_assert(CVT_FP4_ELTS_PER_THREAD == 16,
"MXFP4 experts quant requires PACK16 mode (CUDA >= 12.9)");
#include "launch_bounds_utils.h"
namespace vllm {
// MXFP4 block size constants
// Number of consecutive elements that share one scale factor.
static constexpr int MXFP4_SF_VEC_SIZE = 32;
// Number of threads that cooperate on one scale factor, i.e. the block size
// divided by the number of elements each thread converts:
// For PACK16 mode (CVT_FP4_ELTS_PER_THREAD=16): 2 threads per SF
// For PACK8 mode (CVT_FP4_ELTS_PER_THREAD=8): 4 threads per SF
static constexpr int MXFP4_NUM_THREADS_PER_SF =
MXFP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
// MXFP4 quantization kernel for experts (offsets looked up straight from
// global memory / registers).
// Uses 32-element blocks with E8M0 (UE8M0) scale factors.
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
// SiLU(gate)*up before quantization.
//
// Work decomposition: one loop iteration converts one group of
// CVT_FP4_ELTS_PER_THREAD consecutive elements of one row. Threads walk over
// all numRows * (numCols / CVT_FP4_ELTS_PER_THREAD) groups with a stride of
// gridDim.x * blockDim.x, so any launch configuration covers the full input.
//
// Parameters:
//   numRows, numCols - size of the quantized output, in elements.
//   in               - input activations; with FUSE_SILU_MUL each input row
//                      holds twice as many elements as an output row.
//   out              - packed FP4 output, one fp4_packed_t per group.
//   SFout            - scale-factor output buffer.
//   input_offset_by_experts
//                    - row offsets delimiting each expert's rows; entries
//                      [i, i + 1] bound expert i.
//   output_scale_offset_by_experts
//                    - per-expert offsets into SFout (scaled by numKTiles).
//   n_experts        - number of experts.
//   low_latency      - not read by the body; its presence distinguishes this
//                      overload from the shared-memory variant of the same
//                      name.
//
// Preconditions of the SMALL_NUM_EXPERTS=false path: offsets are fetched as
// int4 vectors in chunks of 16 entries plus one, so input_offset_by_experts
// must be 16-byte aligned and readable up to index
// ceil(n_experts / 16) * 16. Launch it only when n_experts is a multiple of 16.
// NOTE(review): globalIdx and the loop bound are 32-bit, so the total number
// of groups is assumed to fit in int.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts, bool low_latency) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");
  // MXFP4: numKTiles = ceil(numCols / 128) since block_size=32, 4 SFs/tile
  int32_t const numKTiles = (numCols + 127) / 128;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;

  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    // Find the expert owning this row and the row's index within it.
    int rowIdx_in_expert = 0;
    int expert_idx = 0;
    if constexpr (SMALL_NUM_EXPERTS) {
      // Linear scan over the offsets.
      for (int i = 0; i < n_experts; i++) {
        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
        if (rowIdx >= current_offset && rowIdx < next_offset) {
          rowIdx_in_expert = rowIdx - current_offset;
          expert_idx = i;
          break;
        }
      }
    } else {
      // Pull 17 offsets (16 experts) into local storage with four vector
      // loads plus one scalar load, then scan them.
      uint32_t local_offsets[17];
      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
        *reinterpret_cast<int4*>(local_offsets) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start]));
        *reinterpret_cast<int4*>(local_offsets + 4) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 4]));
        *reinterpret_cast<int4*>(local_offsets + 8) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 8]));
        *reinterpret_cast<int4*>(local_offsets + 12) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 12]));
        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
#pragma unroll
        for (int i = 0; i < 16; i++) {
          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
            rowIdx_in_expert = rowIdx - local_offsets[i];
            expert_idx = chunk_start + i;
            break;
          }
        }
      }
    }

    // Load input and optionally apply fused SiLU+Mul.
    // The product is formed in 64-bit: with FUSE_SILU_MUL the input row is
    // twice as wide as the output row, so rowIdx * inColsPerRow can exceed
    // the int range even when globalIdx does not.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow vectors after the gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }

    // In PACK16 mode, each thread outputs 16 E2M1 values = u32x2
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    // Start of this expert's scale factors; offset widened before scaling.
    uint32_t* SFout_in_expert =
        SFout +
        static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) *
            numKTiles;
    // Use MXFP4_NUM_THREADS_PER_SF (2 for PACK16) for 32-element blocks
    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);

    // Block E8M0 scales only; no extra tensor-level scale in this path
    constexpr float SFScaleVal = 1.0f;
    // UE8M0_SF=true for MXFP4 E8M0 scale factors
    out_pos =
        cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF, /*UE8M0_SF=*/true>(
            quant_input, SFScaleVal, sf_out);
  }
}
// Large M_topk variant using shared memory for expert offsets.
//
// Every block first copies the n_experts + 1 row offsets into dynamic shared
// memory, synchronizes, and then each thread locates its row's expert with a
// binary search over that copy. Quantization itself matches the other
// overload: one group of CVT_FP4_ELTS_PER_THREAD elements per loop iteration,
// with threads striding over all groups by gridDim.x * blockDim.x.
//
// Launch requirement: dynamic shared memory of at least
// (n_experts + 1) * sizeof(uint32_t) bytes.
//
// Parameters:
//   numRows, numCols - size of the quantized output, in elements.
//   in               - input activations; with FUSE_SILU_MUL each input row
//                      holds twice as many elements as an output row.
//   out              - packed FP4 output, one fp4_packed_t per group.
//   SFout            - scale-factor output buffer.
//   input_offset_by_experts
//                    - n_experts + 1 row offsets delimiting each expert.
//   output_scale_offset_by_experts
//                    - per-expert offsets into SFout (scaled by numKTiles).
//   n_experts        - number of experts.
//
// Preconditions of the SMALL_NUM_EXPERTS=false path: offsets are copied with
// int4 vector accesses, so input_offset_by_experts must be 16-byte aligned
// and n_experts a multiple of 4; otherwise the last vector runs past both the
// offsets array and the shared-memory buffer.
// NOTE(review): globalIdx and the loop bound are 32-bit, so the total number
// of groups is assumed to fit in int.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");
  // MXFP4: numKTiles = ceil(numCols / 128)
  int32_t const numKTiles = (numCols + 127) / 128;

  // Stage the expert row offsets in shared memory.
  extern __shared__ uint32_t shared_input_offsets[];
  if constexpr (SMALL_NUM_EXPERTS) {
    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
      shared_input_offsets[i] = input_offset_by_experts[i];
    }
  } else {
    // Four offsets per access; the final entry is copied separately.
    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
      *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
          *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
    }
    if (threadIdx.x == 0) {
      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
    }
  }
  // All threads reach this barrier before anyone reads the staged offsets.
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;

  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;
    int rowIdx_in_expert = 0;
    int expert_idx = 0;

    // Binary search through experts using shared memory
    int left = 0, right = n_experts - 1;
    while (left <= right) {
      int mid = (left + right) / 2;
      uint32_t mid_offset = shared_input_offsets[mid];
      uint32_t next_offset = shared_input_offsets[mid + 1];
      if (rowIdx >= mid_offset && rowIdx < next_offset) {
        rowIdx_in_expert = rowIdx - mid_offset;
        expert_idx = mid;
        break;
      } else if (rowIdx < mid_offset) {
        right = mid - 1;
      } else {
        left = mid + 1;
      }
    }

    // The product is formed in 64-bit: with FUSE_SILU_MUL the input row is
    // twice as wide as the output row, so rowIdx * inColsPerRow can exceed
    // the int range even when globalIdx does not.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow vectors after the gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }

    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    // MXFP4 has no global scale - only block-level E8M0 scale factors
    constexpr float SFScaleVal = 1.0f;
    // Start of this expert's scale factors; offset widened before scaling.
    uint32_t* SFout_in_expert =
        SFout +
        static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) *
            numKTiles;
    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
    out_pos =
        cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF, /*UE8M0_SF=*/true>(
            quant_input, SFScaleVal, sf_out);
  }
}
template <typename T, bool FUSE_SILU_MUL = false>
void mxfp4_quant_impl(void* output, void* output_scale, void* input,
void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
int const workSizePerRow = k / ELTS_PER_THREAD;
int const totalWorkSize = m_topk * workSizePerRow;
dim3 block(std::min(workSizePerRow, 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
multiProcessorCount * numBlocksPerSM));
while (grid.x <= multiProcessorCount && block.x > 64) {
grid.x *= 2;
block.x = (block.x + 1) / 2;
}
int const blockRepeat =
(totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
if (n_experts >= 4) {
mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<fp4_packed_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
} else {
mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<fp4_packed_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
}
} else {
if (n_experts >= 16) {
mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
<<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<fp4_packed_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
} else {
mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<fp4_packed_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
}
}
}
} // namespace vllm
/*Quantization entry for mxfp4 experts quantization*/
// Tensor validation helpers. `x` is the tensor and `m` its name; the message
// reads "<name>: must be ...".
#define CHECK_TH_CUDA(x, m) \
  STD_TORCH_CHECK((x).is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  STD_TORCH_CHECK((x).is_contiguous(), m, ": must be contiguous")
// Runs both checks. Wrapped in do/while(0) so the macro behaves as a single
// statement (e.g. under an unbraced `if`); callers supply the `;`.
#define CHECK_INPUT(x, m)   \
  do {                      \
    CHECK_TH_CUDA(x, m);    \
    CHECK_CONTIGUOUS(x, m); \
  } while (0)

// Short aliases for the scalar types used during validation and dispatch.
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Number of elements sharing one MXFP4 scale factor.
static constexpr int MXFP4_BLOCK_SIZE = 32;
// Validates the tensors handed to the MXFP4 experts-quant entry points.
//
// Checks, in order: device/contiguity of every tensor, ranks, dtypes, that k
// is a multiple of MXFP4_BLOCK_SIZE, the lengths of both offset tensors
// (n_experts + 1), and the shapes of `output` and `output_scale` relative to
// m_topk and k. The first violated precondition raises via STD_TORCH_CHECK.
//
// `m_topk` and `k` are supplied by the caller rather than read from `input`
// because the fused SiLU+Mul entry point passes half of the input width as k.
static void validate_mxfp4_experts_quant_inputs(
    torch::stable::Tensor const& output,
    torch::stable::Tensor const& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts, int64_t m_topk, int64_t k) {
  CHECK_INPUT(output, "output");
  CHECK_INPUT(output_scale, "output_scale");
  CHECK_INPUT(input, "input");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");

  STD_TORCH_CHECK(output.dim() == 2, "output must be 2-D");
  STD_TORCH_CHECK(output_scale.dim() == 2, "output_scale must be 2-D");
  STD_TORCH_CHECK(input.dim() == 2, "input must be 2-D");
  STD_TORCH_CHECK(input_offset_by_experts.dim() == 1,
                  "input_offset_by_experts must be 1-D");
  STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1,
                  "output_scale_offset_by_experts must be 1-D");

  STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16,
                  "input must be float16 or bfloat16");
  STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT,
                  "input_offset_by_experts must be int32");
  STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT,
                  "output_scale_offset_by_experts must be int32");
  // output is uint8 (two mxfp4 values packed into one uint8)
  // output_scale is int32 (four E8M0 values packed into one int32)
  STD_TORCH_CHECK(output.scalar_type() == UINT8, "output must be uint8");
  STD_TORCH_CHECK(output_scale.scalar_type() == INT,
                  "output_scale must be int32");

  STD_TORCH_CHECK(k % MXFP4_BLOCK_SIZE == 0, "k must be a multiple of 32");
  STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1,
                  "input_offset_by_experts must have n_experts + 1 entries");
  STD_TORCH_CHECK(
      output_scale_offset_by_experts.size(0) == n_experts + 1,
      "output_scale_offset_by_experts must have n_experts + 1 entries");
  STD_TORCH_CHECK(output.size(0) == m_topk,
                  "output must have one row per input row");
  STD_TORCH_CHECK(output.size(1) == k / 2, "output must have k / 2 columns");

  // Keep the arithmetic in int64_t: k is int64_t, and narrowing to int here
  // would silently truncate for very large k.
  const int64_t scales_k = k / MXFP4_BLOCK_SIZE;
  // K-dimension scale columns padded to a multiple of 4 for swizzle layout
  const int64_t padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 = 4 E8M0 values packed into one int32
  STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k,
                  "output_scale must have ceil(k / 32 / 4) int32 columns");
}
// Host entry point for MXFP4 experts quantization.
//
// Reads the row count and width from `input`, validates all tensors, then
// dispatches on the input dtype and forwards raw data pointers to
// vllm::mxfp4_quant_impl with FUSE_SILU_MUL disabled. The launch uses the
// current CUDA stream of the input tensor's device.
void mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  const auto num_rows = input.size(0);
  const auto num_cols = input.size(1);

  validate_mxfp4_experts_quant_inputs(output, output_scale, input,
                                      input_offset_by_experts,
                                      output_scale_offset_by_experts, n_experts,
                                      num_rows, num_cols);

  // Make the input's device current for the duration of the launch.
  const auto device_index = input.get_device_index();
  const torch::stable::accelerator::DeviceGuard device_guard(device_index);
  const cudaStream_t stream = get_current_cuda_stream(device_index);

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "mxfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::mxfp4_quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(), num_rows, num_cols,
            n_experts, stream);
      });
}
// Host entry point for fused SiLU+Mul+MXFP4 experts quantization.
//
// The input width must be even ("gate || up" per the check message); half of
// it is used as k for validation and for the kernel. After validation the
// call dispatches on the input dtype and forwards raw data pointers to
// vllm::mxfp4_quant_impl with FUSE_SILU_MUL enabled, on the current CUDA
// stream of the input tensor's device.
void silu_and_mul_mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  const auto num_rows = input.size(0);
  const auto fused_width = input.size(1);
  STD_TORCH_CHECK(fused_width % 2 == 0,
                  "input width must be even (gate || up)");
  const auto half_width = fused_width / 2;

  validate_mxfp4_experts_quant_inputs(output, output_scale, input,
                                      input_offset_by_experts,
                                      output_scale_offset_by_experts, n_experts,
                                      num_rows, half_width);

  // Make the input's device current for the duration of the launch.
  const auto device_index = input.get_device_index();
  const torch::stable::accelerator::DeviceGuard device_guard(device_index);
  const cudaStream_t stream = get_current_cuda_stream(device_index);

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "silu_mul_mxfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::mxfp4_quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(), num_rows, half_width,
            n_experts, stream);
      });
}
......@@ -116,6 +116,12 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// cutlass mxfp4 block scaled group GEMM (MXFP4 x MXFP4 MoE)
ops.def(
"cutlass_mxfp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
......@@ -149,6 +155,19 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Compute MXFP4 experts quantization (32-element blocks, E8M0 SFs).
ops.def(
"mxfp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()");
// Fused SiLU+Mul+MXFP4 experts quantization.
ops.def(
"silu_and_mul_mxfp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
......@@ -233,6 +252,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
ops.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant));
ops.impl("silu_and_mul_mxfp4_experts_quant",
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for SM100 CUTLASS MXFP4 x MXFP4 grouped MoE kernels."""
import random
import pytest
import torch
from tests.kernels.utils import torch_moe_single
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
# Fix the seeds at import time so the randomly drawn problem sizes
# (random.randint below) are reproducible. set_random_seed is a vLLM helper;
# presumably it seeds torch as well -- confirm against vllm.utils.torch_utils.
random.seed(42)
set_random_seed(42)
# Number of K elements covered by one scale factor; used below to compute the
# number of scale factors per row (k_g // MXFP4_BLOCK_SIZE).
MXFP4_BLOCK_SIZE = 32
def align(val: int, alignment: int = 128) -> int:
    """Round ``val`` up to the nearest multiple of ``alignment``."""
    num_chunks = (val + alignment - 1) // alignment
    return int(num_chunks * alignment)
def calc_diff(x, y):
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The result is 0 when the two tensors are identical and grows as they
    diverge.
    """
    lhs = x.double()
    rhs = y.double()
    energy = (lhs * lhs + rhs * rhs).sum()
    cross = (lhs * rhs).sum()
    return 1 - 2 * cross / energy
def is_sm100_supported() -> bool:
    """Report whether the current platform is CUDA with capability family 100."""
    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        return on_cuda
    return current_platform.is_device_capability_family(100)
def compute_ref_output(
    input_tensor: torch.Tensor,
    weight_list: list[torch.Tensor],
    expert_offsets: list[int],
    expert_offset: int,
    num_experts: int,
) -> torch.Tensor:
    """Reference output using torch_moe_single with top-1 routing.

    A score matrix of shape (expert_offset, num_experts) is filled with -1e9
    and then set to 0.0 for the rows assigned to each expert, so that top-1
    selection picks the intended expert for every row.

    Args:
        input_tensor: Concatenated per-expert inputs.
        weight_list: One weight tensor per expert; stacked along dim 0.
        expert_offsets: Start row of each expert in ``input_tensor``.
        expert_offset: Total number of rows (end of the last expert).
        num_experts: Number of experts.
    """
    routing_scores = torch.full(
        (expert_offset, num_experts),
        -1e9,
        device=input_tensor.device,
        dtype=torch.float32,
    )
    for expert_id in range(num_experts):
        row_begin = expert_offsets[expert_id]
        # The last expert runs to the total row count.
        if expert_id + 1 < num_experts:
            row_end = expert_offsets[expert_id + 1]
        else:
            row_end = expert_offset
        routing_scores[row_begin:row_end, expert_id] = 0.0

    stacked_weights = torch.stack(weight_list, dim=0)
    return torch_moe_single(input_tensor, stacked_weights, routing_scores, topk=1)
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="cutlass_mxfp4_group_mm requires CUDA SM100",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_cutlass_mxfp4_grouped_mm(num_experts, out_dtype):
    """
    Test the MXFP4 grouped GEMM kernel by:
    1. Creating random per-expert inputs and weights
    2. Quantizing both to MXFP4 using the CUDA kernel
    3. Running the CUTLASS grouped GEMM
    4. Comparing against BF16 reference

    All experts share the same N and K; only the per-expert row count M
    varies. The comparison is done per expert with calc_diff and a loose
    threshold, since the reference is computed from unquantized tensors.
    """
    device = "cuda"
    alignment = 128
    # N and K must be multiples of 128 for clean swizzle layout
    n_g = random.randint(1, 16) * alignment
    k_g = random.randint(1, 16) * alignment
    # Running total of rows; after the loop it holds M_total.
    expert_offset = 0
    # Start row of each expert in the concatenated input (length num_experts).
    expert_offsets_input = []
    problem_sizes = []
    input_list = []
    weight_list = []
    for g in range(num_experts):
        m_g = random.randint(1, 256)
        expert_offsets_input.append(expert_offset)
        expert_offset += m_g
        problem_sizes.append([m_g, n_g, k_g])
        input_list.append(
            torch.normal(0.0, std=0.5, size=(m_g, k_g), device=device, dtype=out_dtype)
        )
        weight_list.append(
            torch.normal(0.0, std=0.5, size=(n_g, k_g), device=device, dtype=out_dtype)
        )
    input_tensor = torch.concat(input_list, dim=0)  # [M_total, K]
    # --- Quantize INPUTS via mxfp4_experts_quant ---
    # Scale-factor row offsets: each expert's row count is rounded up to a
    # multiple of 128 and accumulated, giving num_experts + 1 boundaries.
    input_bs_offsets = []
    tot = 0
    for g in range(num_experts):
        input_bs_offsets.append(tot)
        tot += align(problem_sizes[g][0], 128)
    input_bs_offsets.append(tot)
    # The quant op takes num_experts + 1 boundaries, so append the total.
    _inp_expert_offsets = torch.tensor(
        expert_offsets_input + [expert_offset], device=device, dtype=torch.int32
    )
    _inp_bs_offsets = torch.tensor(input_bs_offsets, device=device, dtype=torch.int32)
    input_quant, input_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _inp_expert_offsets,
        _inp_bs_offsets,
        num_experts,
        topk=1,
    )
    # --- Quantize WEIGHTS via mxfp4_experts_quant ---
    # Treat each expert's N weight rows as an "expert" with N tokens
    weight_tensor = torch.concat(weight_list, dim=0)  # [E*N, K]
    weight_expert_offsets = [g * n_g for g in range(num_experts)] + [num_experts * n_g]
    # N is always multiple of 128, so blockscale offsets are clean
    weight_bs_offsets = [g * n_g for g in range(num_experts)] + [num_experts * n_g]
    _wt_expert_offsets = torch.tensor(
        weight_expert_offsets, device=device, dtype=torch.int32
    )
    _wt_bs_offsets = torch.tensor(weight_bs_offsets, device=device, dtype=torch.int32)
    weight_quant, weight_sf = ops.mxfp4_experts_quant(
        weight_tensor,
        _wt_expert_offsets,
        _wt_bs_offsets,
        num_experts,
        topk=1,
    )
    # Reshape weight quantized data to [E, N, K//2]
    weight_quant = weight_quant[: num_experts * n_g].view(num_experts, n_g, k_g // 2)
    # Reshape weight scale factors to [E, N, K//32]
    # The quant kernel produces uint8 SF buffer. Each row has K//32 SFs.
    scales_per_row = k_g // MXFP4_BLOCK_SIZE
    # The SF buffer is over-allocated by the wrapper; keep only the leading
    # E * N * (K//32) bytes that belong to the weights.
    weight_sf_flat = weight_sf.view(-1)[: num_experts * n_g * scales_per_row]
    weight_sf_3d = weight_sf_flat.view(num_experts, n_g, scales_per_row)
    # Output
    output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype)
    _problem_sizes = torch.tensor(problem_sizes, device=device, dtype=torch.int32)
    # The GEMM is given per-expert start offsets only (no trailing total).
    _expert_offsets = torch.tensor(
        expert_offsets_input, device=device, dtype=torch.int32
    )
    _input_bs = torch.tensor(input_bs_offsets[:-1], device=device, dtype=torch.int32)
    # Run the MXFP4 grouped GEMM
    ops.cutlass_mxfp4_moe_mm(
        output,
        input_quant,
        weight_quant,
        input_sf,
        weight_sf_3d,
        _problem_sizes,
        _expert_offsets,
        _input_bs,
    )
    # Reference: BF16 matmul
    ref_output = compute_ref_output(
        input_tensor=input_tensor,
        weight_list=weight_list,
        expert_offsets=expert_offsets_input,
        expert_offset=expert_offset,
        num_experts=num_experts,
    )
    # Compare per-expert
    for g in range(num_experts):
        start = expert_offsets_input[g]
        end = expert_offsets_input[g + 1] if g + 1 < num_experts else expert_offset
        # Nothing to compare for an expert with no rows.
        if start == end:
            continue
        baseline = ref_output[start:end]
        actual = output[start:end]
        diff = calc_diff(actual, baseline)
        print(
            f"m_g={end - start} n_g={n_g} k_g={k_g} "
            f"num_experts={num_experts}, "
            f"out_dtype={out_dtype}, diff={diff:.5f}"
        )
        # FP4 quantization is very lossy (~4 bits precision)
        # Comparing quantized vs full-precision gives cosine diff of 0.05-0.15
        assert diff < 0.15, f"Expert {g}: diff={diff:.5f} exceeds threshold"
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="mxfp4_experts_quant requires CUDA SM100",
)
def test_mxfp4_experts_quant_basic():
    """
    Basic smoke test for the MXFP4 experts quantization kernel.

    Quantizes a small bfloat16 input split evenly across four experts and
    checks the shape and dtype of the packed output, the dtype of the scale
    factors, and that the packed output is not all zeros.
    """
    device = "cuda"
    num_experts = 4
    k = 256
    tokens_per_expert = 16
    total_tokens = tokens_per_expert * num_experts
    input_tensor = torch.randn(total_tokens, k, device=device, dtype=torch.bfloat16) / 5
    expert_offsets = [i * tokens_per_expert for i in range(num_experts + 1)]
    # Each expert owns its own 128-row-aligned region of scale factors, so the
    # offsets are the running sum of the per-expert padded row counts
    # (0, 128, 256, ...). Aligning the cumulative token offset instead would
    # map several experts onto the same scale-factor rows.
    padded_tokens_per_expert = align(tokens_per_expert, 128)
    blockscale_offsets = [
        i * padded_tokens_per_expert for i in range(num_experts + 1)
    ]
    _expert_offsets = torch.tensor(expert_offsets, device=device, dtype=torch.int32)
    _blockscale_offsets = torch.tensor(
        blockscale_offsets, device=device, dtype=torch.int32
    )
    output, output_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _expert_offsets,
        _blockscale_offsets,
        num_experts,
        topk=1,
    )
    assert output.shape == (total_tokens, k // 2)
    assert output.dtype == torch.uint8
    assert output_sf.dtype == torch.uint8
    assert output.any(), "Quantized output is all zeros"
    print(
        f"MXFP4 experts quant: output shape={output.shape}, sf shape={output_sf.shape}"
    )
    print("PASSED")
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
......@@ -1150,6 +1150,38 @@ def cutlass_fp4_moe_mm(
)
def cutlass_mxfp4_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    sf_offsets: torch.Tensor,
):
    """
    MXFP4 x MXFP4 block-scaled grouped GEMM for MoE.

    Thin wrapper that forwards all arguments, in order, to the
    ``_C.cutlass_mxfp4_group_mm`` custom op, which writes into ``out_tensors``.

    Args:
        out_tensors: Output buffer written by the op.
        a_tensors / b_tensors: MXFP4 packed activations / weights
            (uint8, 2 E2M1 values per byte).
        a_scales / b_scales: E8M0 block scales (uint8, swizzled layout),
            one scale per 32-element block.
        problem_sizes: (num_experts, 3) with (M, N, K) per expert.
        expert_offsets / sf_offsets: Expert boundary indices.

    Note:
        The epilogue uses scalar alpha=1, beta=0 inside the CUDA op; there
        are no global scales.
    """
    grouped_mm = torch.ops._C.cutlass_mxfp4_group_mm
    op_args = (
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        problem_sizes,
        expert_offsets,
        sf_offsets,
    )
    return grouped_mm(*op_args)
def mxfp8_experts_quant(
input_tensor: torch.Tensor,
problem_sizes: torch.Tensor,
......@@ -1848,6 +1880,109 @@ def silu_and_mul_scaled_fp4_experts_quant(
return output, output_scales
def mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize packed MoE inputs to MXFP4.

    MXFP4 uses 32-element blocks with E8M0 (power-of-two) scale factors and
    has no global scale.

    Args:
        input_tensor: [m_topk, k] BF16/FP16 activations
        expert_offsets: [n_experts+1] token boundaries per expert
        blockscale_offsets: [n_experts+1] SF row boundaries per expert
        n_experts: number of experts
        topk: number of top-k experts

    Returns:
        output: [m_topk, k//2] packed E2M1 values (uint8)
        output_scales: E8M0 blockscales in swizzled layout. Allocated as
            int32 with MAX_TOKENS_PER_EXPERT * topk rows and returned as a
            uint8 view.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2

    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape

    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_mxfp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )

    device = input_tensor.device
    # One scale per 32 elements; four scales are packed into each int32
    # column, so round the column count up.
    num_block_scales = k // 32
    packed_scale_cols = (num_block_scales + (4 - 1)) // 4

    output = torch.empty(m_numtopk, k // 2, device=device, dtype=torch.uint8)
    scale_rows = MAX_TOKENS_PER_EXPERT * topk
    output_scales = torch.empty(
        scale_rows, packed_scale_cols, dtype=torch.int32, device=device
    )

    torch.ops._C.mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )

    # E8M0 SFs are stored as uint8
    return output, output_scales.view(torch.uint8)
def silu_and_mul_mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused SiLU+Mul+MXFP4 quantization for MoE intermediate activations.

    MXFP4 has no global scale - only block-level E8M0 scale factors.

    Args:
        input_tensor: 2-D tensor whose width is 2 * k (gate || up layout).
        expert_offsets: Token boundaries per expert.
        blockscale_offsets: Scale-factor row boundaries per expert.
        n_experts: number of experts
        topk: number of top-k experts

    Returns:
        output: [m_numtopk, k // 2] packed values (uint8)
        output_scales: Block scales, allocated as int32 with
            MAX_TOKENS_PER_EXPERT * topk rows and returned as a uint8 view.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2

    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k_times_2 = input_tensor.shape
    assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
    k = k_times_2 // 2
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk

    device = input_tensor.device
    # One scale per 32 elements; four scales are packed into each int32
    # column, so round the column count up.
    num_block_scales = k // 32
    packed_scale_cols = (num_block_scales + (4 - 1)) // 4

    output = torch.empty(m_numtopk, k // 2, device=device, dtype=torch.uint8)
    scale_rows = MAX_TOKENS_PER_EXPERT * topk
    output_scales = torch.empty(
        scale_rows, packed_scale_cols, dtype=torch.int32, device=device
    )

    torch.ops._C.silu_and_mul_mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )

    return output, output_scales.view(torch.uint8)
# fp8
def scaled_fp8_quant(
input: torch.Tensor,
......
......@@ -762,6 +762,25 @@ def nvfp4_moe_quant_config(
)
def mxfp4_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for MXFP4 x MXFP4 MoE.

    MXFP4 relies on block scaling only (E8M0 scales, 32-element groups), so
    this config carries just the two weight scales: no alphas or global
    activation scales, no per-token / per-channel quantization, and no block
    shape.
    """
    config_kwargs = dict(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    )
    return FusedMoEQuantConfig.make("mxfp4", **config_kwargs)
def nvfp4_w4a16_moe_quant_config(
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
......
......@@ -36,6 +36,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Dynamic,
kMxfp4Static,
kNvfp4Dynamic,
kNvfp4Static,
)
......@@ -795,6 +797,299 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
)
def run_cutlass_moe_mxfp4(
    output: torch.Tensor,
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w2_fp4: torch.Tensor,
    w2_blockscale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    apply_router_weight_on_input: bool = False,
) -> None:
    """MXFP4 x MXFP4 MoE implementation using CUTLASS grouped GEMM.

    Pipeline:
      1. Build per-expert offsets, problem sizes and row permutations from
         ``topk_ids``.
      2. Permute the activations, quantize them to MXFP4, and run the first
         grouped GEMM against ``w1_fp4``.
      3. Apply the activation and quantize the intermediate result (fused
         into one op for SILU).
      4. Run the second grouped GEMM against ``w2_fp4``, un-permute, apply
         the router weights (unless already applied on the input), and sum
         over the top-k dimension into ``output``.

    Args:
        output: Destination tensor; must have the same dtype as ``a``.
        a: [m, k] half or bfloat16 activations. Not modified.
        w1_fp4: [e, w1_n, k // 2] uint8 packed weights, where ``w1_n`` is
            ``2 * n`` for gated activations and ``n`` otherwise.
        w1_blockscale: Rank-3 block scales for ``w1_fp4``.
        w2_fp4: [e, k, n // 2] uint8 packed weights.
        w2_blockscale: Rank-3 block scales for ``w2_fp4``.
        topk_weights: [m, topk] router weights.
        topk_ids: [m, topk] selected expert ids.
        activation: Activation applied between the two GEMMs.
        workspace13: Scratch buffer backing both GEMM outputs.
        workspace2: Scratch buffer for the non-fused activation output.
        m, n, k, e: Tokens, intermediate size, hidden size, expert count.
        device: Device on which the metadata tensors are allocated.
        apply_router_weight_on_input: Multiply the router weights into the
            input instead of the output (only implemented for topk=1).
    """
    is_gated = activation.is_gated
    w1_n = n * 2 if is_gated else n

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert (
        w1_fp4.ndim == 3
        and w2_fp4.ndim == 3
        and w1_blockscale.ndim == 3
        and w2_blockscale.ndim == 3
    ), "All Weights must be of rank 3 for cutlass_moe_mxfp4"
    m_a, k_a = a.shape
    e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
    e_w2, k_w2, half_n_w2 = w2_fp4.shape

    assert e_w1 == e_w2 and e_w1 == e
    assert k_a == half_k_w1 * 2 and k == k_w2
    assert w1_n_actual == w1_n and half_n_w2 * 2 == n
    assert m == m_a
    assert 2 * half_k_w1 == k_w2
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert topk_weights.size(0) == m and topk_ids.size(0) == m

    out_dtype = a.dtype
    num_topk = topk_ids.size(1)

    # Metadata filled in by get_cutlass_moe_mm_data below.
    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    if apply_router_weight_on_input:
        assert num_topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        # Out-of-place on purpose: an in-place multiply would silently
        # rescale the caller's hidden states, which may be reused after
        # this call.
        a = a * topk_weights.to(out_dtype)

    ops.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        e,
        n,
        k,
        blockscale_offsets,
        is_gated=is_gated,
    )

    a = ops.shuffle_rows(a, a_map)
    rep_a_fp4, rep_a_blockscale = ops.mxfp4_experts_quant(
        a,
        expert_offsets,
        blockscale_offsets,
        e,
        num_topk,
    )

    # c1 and c3 both live in workspace13; c1 is fully consumed (quantized
    # into int_fp4) before the second GEMM writes c3.
    c1 = _resize_cache(workspace13, (m * num_topk, w1_n))
    c2 = _resize_cache(workspace2, (m * num_topk, n))
    c3 = _resize_cache(workspace13, (m * num_topk, k))

    # The GEMM takes per-expert start offsets only, hence the [:-1] slices.
    ops.cutlass_mxfp4_moe_mm(
        c1,
        rep_a_fp4,
        w1_fp4,
        rep_a_blockscale,
        w1_blockscale,
        problem_sizes1,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del rep_a_fp4, rep_a_blockscale

    if activation == MoEActivation.SILU:
        # Fused activation + quantization reads c1 directly.
        int_fp4, int_blockscale = ops.silu_and_mul_mxfp4_experts_quant(
            c1, expert_offsets, blockscale_offsets, e, num_topk
        )
    else:
        apply_moe_activation(activation, c2, c1)
        int_fp4, int_blockscale = ops.mxfp4_experts_quant(
            c2, expert_offsets, blockscale_offsets, e, num_topk
        )

    ops.cutlass_mxfp4_moe_mm(
        c3,
        int_fp4,
        w2_fp4,
        int_blockscale,
        w2_blockscale,
        problem_sizes2,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del int_fp4, int_blockscale

    # Restore the original token order before reducing over top-k.
    c3 = ops.shuffle_rows(c3, c_map)

    assert output.dtype == out_dtype
    if not apply_router_weight_on_input:
        output.copy_(
            (
                c3.view(m, num_topk, k)
                * topk_weights.view(m, num_topk, 1).to(out_dtype)
            ).sum(dim=1),
            non_blocking=True,
        )
    else:
        output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
def swizzle_mxfp4_scales(
    scales: torch.Tensor,
    N: int,
    K: int,
) -> torch.Tensor:
    """Rearrange flat [N, K//32] E8M0 scales into the CUTLASS tiled layout.

    The rows are zero-padded to a multiple of 128 and the scale columns to a
    multiple of 4. The padded matrix is then laid out as
        [numMTiles, numKTiles, 32, 4, 4]
    with numMTiles = ceil(N / 128) and numKTiles = ceil((K // 32) / 4), where
    for a scale at (mIdx, kIdx):
        mTileIdx  = mIdx / 128
        outerMIdx = mIdx % 32
        innerMIdx = (mIdx / 32) % 4
        kTileIdx  = kIdx / 4
        innerKIdx = kIdx % 4
    and the index order is (mTileIdx, kTileIdx, outerMIdx, innerMIdx,
    innerKIdx).

    Args:
        scales: uint8 tensor assignable to an [N, K // 32] region.
        N: Number of rows.
        K: Number of unquantized columns (32 per scale).

    Returns:
        A flat uint8 tensor with numMTiles * numKTiles * 512 elements.
    """
    assert scales.dtype == torch.uint8

    scale_cols = K // 32  # E8M0 scale values per row
    m_tiles = (N + 127) // 128
    k_tiles = (scale_cols + 3) // 4

    # Zero-filled canvas with both dimensions rounded up to whole tiles.
    canvas = torch.zeros(
        (m_tiles * 128, k_tiles * 4), dtype=torch.uint8, device=scales.device
    )
    canvas[:N, :scale_cols] = scales

    # Split rows into (mTileIdx, innerMIdx, outerMIdx) and columns into
    # (kTileIdx, innerKIdx).
    blocks = canvas.reshape(m_tiles, 4, 32, k_tiles, 4)
    # Swapping dims 1 and 3 yields (mTileIdx, kTileIdx, outerMIdx,
    # innerMIdx, innerKIdx).
    return blocks.transpose(1, 3).contiguous().reshape(-1)
class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular):
    """CUTLASS MXFP4 x MXFP4 fused MoE expert implementation.

    Delegates the actual computation to ``run_cutlass_moe_mxfp4``, which
    quantizes activations, runs both grouped GEMMs, and applies the top-k
    weighting and reduction itself.
    """

    @property
    def expects_unquantized_inputs(self) -> bool:
        # Activations are quantized inside run_cutlass_moe_mxfp4.
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        """True on CUDA devices in capability family 100."""
        platform = current_platform
        return platform.is_cuda() and platform.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Only static MXFP4 weights with dynamic MXFP4 activations."""
        supported_scheme = (kMxfp4Static, kMxfp4Dynamic)
        return (weight_key, activation_key) == supported_scheme

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        supported_activations = [
            MoEActivation.SILU,
            MoEActivation.GELU,
            MoEActivation.SWIGLUOAI,
            MoEActivation.SWIGLUSTEP,
            MoEActivation.SILU_NO_MUL,
            MoEActivation.GELU_NO_MUL,
            MoEActivation.RELU2_NO_MUL,
        ]
        return activation in supported_activations

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Supported only when the expert-parallel size is 1."""
        return moe_parallel_config.ep_size == 1

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def supports_expert_map(self) -> bool:
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # run_cutlass_moe_mxfp4 already applies topk weights and sums over
        # top-k, so the finalize step has nothing left to do.
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the (workspace1, workspace2, output) shapes.

        workspace1 is sized with ``max(2 * N, K)`` columns because it backs
        both GEMM outputs in run_cutlass_moe_mxfp4.
        """
        num_rows = M * topk
        workspace1 = (num_rows, max(2 * N, K))
        workspace2 = (num_rows, N)
        output = (M, K)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the MXFP4 MoE computation and write the result into ``output``."""
        e, m, _, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
        # The intermediate size is taken from w2: its last dim is doubled
        # here, matching the ``half_n_w2 * 2 == n`` check in
        # run_cutlass_moe_mxfp4.
        n = w2.shape[2] * 2

        run_cutlass_moe_mxfp4(
            output=output,
            a=hidden_states,
            w1_fp4=w1,
            w1_blockscale=self.w1_scale,
            w2_fp4=w2,
            w2_blockscale=self.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            workspace13=workspace13,
            workspace2=workspace2,
            m=m,
            n=n,
            k=k,
            e=e,
            device=hidden_states.device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
......@@ -11,6 +12,10 @@ from vllm.model_executor.layers.fused_moe import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsMxfp4,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
......@@ -36,7 +41,14 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
self.use_cutlass_mxfp4 = CutlassExpertsMxfp4._supports_current_device()
self.experts_cls: type[mk.FusedMoEExperts]
if self.use_cutlass_mxfp4:
logger.info_once("Using CutlassExpertsMxfp4 for MXFP4 MoE", scope="local")
self.experts_cls = CutlassExpertsMxfp4
else:
logger.info_once("Using MarlinExperts for MXFP4 MoE", scope="local")
self.experts_cls = MarlinExperts
def create_weights(
self,
......@@ -109,11 +121,19 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
if self.use_cutlass_mxfp4:
# W4A4: both weights and activations quantized to MXFP4
return mxfp4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
else:
# W4A16: weight-only via Marlin
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter(
......@@ -126,13 +146,45 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
)
delattr(layer, "w2_weight_packed")
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression "
"will be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin(layer)
if self.use_cutlass_mxfp4:
# Swizzle weight scales from flat checkpoint layout [E, N, K//32]
# to CUTLASS tiled layout [E, numMTiles*numKTiles*512].
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
swizzle_mxfp4_scales,
)
E = layer.w13_weight_scale.shape[0]
w13_N = layer.w13_weight_scale.shape[1]
w13_scale_K = layer.w13_weight_scale.shape[2]
w13_K = w13_scale_K * 32
w2_M = layer.w2_weight_scale.shape[1]
w2_scale_N = layer.w2_weight_scale.shape[2]
w2_N = w2_scale_N * 32
swizzled_w13 = []
swizzled_w2 = []
for e_idx in range(E):
s13 = layer.w13_weight_scale[e_idx]
sw13 = swizzle_mxfp4_scales(s13, w13_N, w13_K)
swizzled_w13.append(sw13.reshape(w13_N, w13_scale_K))
s2 = layer.w2_weight_scale[e_idx]
sw2 = swizzle_mxfp4_scales(s2, w2_M, w2_N)
swizzled_w2.append(sw2.reshape(w2_M, w2_scale_N))
layer.w13_weight_scale = torch.nn.Parameter(
torch.stack(swizzled_w13), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
torch.stack(swizzled_w2), requires_grad=False
)
else:
logger.warning_once(
"Your GPU does not have native support for FP4 computation "
"but FP4 quantization is being used. Weight-only FP4 "
"compression will be used leveraging the Marlin kernel. "
"This may degrade performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin(layer)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment