Unverified Commit a8bffaa1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Kernel] Add MXFP4 W4A4 CUTLASS MoE kernel for SM100 (#37463)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 5cdddddd
......@@ -141,6 +141,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_flashinfer_moe.py
......
......@@ -952,7 +952,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
......
......@@ -134,4 +134,27 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
// Quantizes packed MoE activations to MXFP4 (E2M1 values with E8M0 block
// scales over 32-element groups).
//   output                         - uint8 [m_topk, k / 2], two E2M1 values
//                                    per byte (written).
//   output_scale                   - int32, four E8M0 scales per element
//                                    (written).
//   input                          - Half/BFloat16 [m_topk, k].
//   input_offset_by_experts        - int32 [n_experts + 1] row boundaries.
//   output_scale_offset_by_experts - int32 [n_experts + 1] scale-row
//                                    boundaries.
//   n_experts                      - number of experts.
void mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
// Fused SiLU(gate) * up followed by MXFP4 quantization for MoE experts.
// `input` is [m_topk, 2 * k] in gate || up layout; the remaining arguments
// have the same meaning as for mxfp4_experts_quant, with `output` sized for
// the k-wide activation result.
void silu_and_mul_mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
// CUTLASS MXFP4 x MXFP4 block-scaled grouped GEMM for MoE.
//   output         - result tensor (written).
//   a, b           - packed MXFP4 activations / weights.
//   a_blockscale,
//   b_blockscales  - E8M0 block scales for a and b.
//   problem_sizes  - per-expert problem sizes.
//   expert_offsets - per-expert row offsets into a / output.
//   sf_offsets     - per-expert offsets into the activation scale factors.
// NOTE(review): exact shapes/dtypes are enforced in the kernel source, which
// is not visible from this header - confirm there.
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
#endif
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* MXFP4 activation quantization kernel for MoE experts.
* Quantizes BF16/FP16 activations to MXFP4: E2M1 values with E8M0 block scales
* over 32-element groups.
*
* Uses PACK16 E2M1 conversion helpers (nvfp4_utils.cuh) configured for:
* - Block size 32 (2 threads per SF in PACK16 mode)
* - E8M0 (power-of-two) scale factors
* - SF layout: [numMTiles, numKTiles, 32, 4, 4] where numKTiles=ceil(K/128)
*/
// MXFP4 requires PACK16 mode (16 elements per thread) so that
// 2 threads cover 32-element blocks. This requires CUDA >= 12.9.
// Must be defined before any header that (transitively) includes
// nvfp4_utils.cuh.
#define NVFP4_ENABLE_ELTS16 1
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
static_assert(CVT_FP4_ELTS_PER_THREAD == 16,
"MXFP4 experts quant requires PACK16 mode (CUDA >= 12.9)");
#include "launch_bounds_utils.h"
namespace vllm {
// MXFP4 block size constants.
// Number of consecutive elements that share one E8M0 scale factor.
static constexpr int MXFP4_SF_VEC_SIZE = 32;
// Number of threads that together cover one scale-factor block. The
// static_assert above pins CVT_FP4_ELTS_PER_THREAD to 16 (PACK16 mode), so
// this evaluates to 32 / 16 = 2 in this translation unit.
static constexpr int MXFP4_NUM_THREADS_PER_SF =
MXFP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
// MXFP4 quantization kernel for experts (low-latency overload; expert offsets
// are read straight from global memory into registers).
//
// Uses 32-element blocks with E8M0 (UE8M0) scale factors.
// When FUSE_SILU_MUL=true, expects input with gate||up layout (row width
// 2 * numCols) and fuses SiLU(gate)*up before quantization.
//
// Launch: 1-D grid and 1-D block of at most 512 threads; no dynamic shared
// memory. Work is distributed with a grid-stride loop, one work item per
// CVT_FP4_ELTS_PER_THREAD consecutive elements of a row.
//
// Parameters:
//   numRows, numCols  - logical [rows, cols] of the quantized activation.
//   in                - input activations (row width numCols, or 2 * numCols
//                       when FUSE_SILU_MUL).
//   out               - packed E2M1 output, one fp4_packed_t per work item.
//   SFout             - scale-factor output buffer.
//   input_offset_by_experts        - n_experts + 1 row boundaries.
//   output_scale_offset_by_experts - per-expert scale-row offsets.
//   n_experts         - number of experts.
//   low_latency       - unused; only distinguishes this overload.
//
// NOTE(review): the int4 loads below assume input_offset_by_experts is
// 16-byte aligned - confirm against the allocator used by callers.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts, bool low_latency) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");
  // MXFP4: numKTiles = ceil(numCols / 128) since block_size=32, 4 SFs/tile
  int32_t const numKTiles = (numCols + 127) / 128;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;
    int rowIdx_in_expert = 0;
    int expert_idx = 0;
    if constexpr (SMALL_NUM_EXPERTS) {
      // Few experts: linear scan over the boundaries.
      for (int i = 0; i < n_experts; i++) {
        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
        if (rowIdx >= current_offset && rowIdx < next_offset) {
          rowIdx_in_expert = rowIdx - current_offset;
          expert_idx = i;
          break;
        }
      }
    } else {
      // Scan the boundaries in chunks of 16 experts (17 offsets) held in
      // registers.
      uint32_t local_offsets[17];
      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
        if (chunk_start + 16 <= n_experts) {
          // Full chunk: all 17 offsets exist, use vectorized loads.
          *reinterpret_cast<int4*>(local_offsets) =
              __ldca(reinterpret_cast<const int4*>(
                  &input_offset_by_experts[chunk_start]));
          *reinterpret_cast<int4*>(local_offsets + 4) =
              __ldca(reinterpret_cast<const int4*>(
                  &input_offset_by_experts[chunk_start + 4]));
          *reinterpret_cast<int4*>(local_offsets + 8) =
              __ldca(reinterpret_cast<const int4*>(
                  &input_offset_by_experts[chunk_start + 8]));
          *reinterpret_cast<int4*>(local_offsets + 12) =
              __ldca(reinterpret_cast<const int4*>(
                  &input_offset_by_experts[chunk_start + 12]));
          local_offsets[16] =
              __ldca(&input_offset_by_experts[chunk_start + 16]);
        } else {
          // Tail chunk (n_experts not a multiple of 16): the offsets array
          // only has n_experts + 1 entries, so clamp every index to
          // n_experts. Clamped slots repeat the final boundary, which makes
          // their [begin, end) ranges empty and prevents false matches.
#pragma unroll
          for (int i = 0; i <= 16; i++) {
            int const idx =
                (chunk_start + i < n_experts) ? (chunk_start + i) : n_experts;
            local_offsets[i] = __ldca(&input_offset_by_experts[idx]);
          }
        }
#pragma unroll
        for (int i = 0; i < 16; i++) {
          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
            rowIdx_in_expert = rowIdx - local_offsets[i];
            expert_idx = chunk_start + i;
            break;
          }
        }
      }
    }
    // Load input and optionally apply fused SiLU+Mul.
    // Widen before multiplying so the offset cannot wrap in 32-bit arithmetic.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow packed vectors after gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }
    // In PACK16 mode, each thread outputs 16 E2M1 values = u32x2
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];
    uint32_t* SFout_in_expert =
        SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
    // Use MXFP4_NUM_THREADS_PER_SF (2 for PACK16) for 32-element blocks
    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
    // Block E8M0 scales only; no extra tensor-level scale in this path
    constexpr float SFScaleVal = 1.0f;
    // UE8M0_SF=true for MXFP4 E8M0 scale factors
    out_pos =
        cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF, /*UE8M0_SF=*/true>(
            quant_input, SFScaleVal, sf_out);
  }
}
// Large M_topk variant: stages the expert offsets in shared memory and
// locates each row's expert with a binary search.
//
// Launch: 1-D grid and 1-D block of at most 1024 threads. Requires
// (n_experts + 1) * sizeof(uint32_t) bytes of dynamic shared memory.
// Work is distributed with a grid-stride loop, one work item per
// CVT_FP4_ELTS_PER_THREAD consecutive elements of a row.
//
// When FUSE_SILU_MUL=true the input rows are gate||up (width 2 * numCols) and
// SiLU(gate)*up is applied before quantization.
//
// NOTE(review): the int4 copies below assume input_offset_by_experts and the
// dynamic shared buffer are 16-byte aligned - confirm.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");
  // MXFP4: numKTiles = ceil(numCols / 128)
  int32_t const numKTiles = (numCols + 127) / 128;
  extern __shared__ uint32_t shared_input_offsets[];
  if constexpr (SMALL_NUM_EXPERTS) {
    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
      shared_input_offsets[i] = input_offset_by_experts[i];
    }
  } else {
    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
      if (i + 4 <= n_experts + 1) {
        // All four entries i..i+3 lie inside the n_experts + 1 offsets.
        *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
            *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
      } else {
        // Partial group at the end: copy element-wise so neither the global
        // offsets nor the (n_experts + 1)-entry shared buffer is overrun.
        for (int j = i; j <= n_experts; ++j) {
          shared_input_offsets[j] = input_offset_by_experts[j];
        }
      }
    }
    if (threadIdx.x == 0) {
      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
    }
  }
  // Every thread reaches this barrier; offsets are visible block-wide after.
  __syncthreads();
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;
    int rowIdx_in_expert = 0;
    int expert_idx = 0;
    // Binary search through experts using shared memory
    int left = 0, right = n_experts - 1;
    while (left <= right) {
      int mid = (left + right) / 2;
      uint32_t mid_offset = shared_input_offsets[mid];
      uint32_t next_offset = shared_input_offsets[mid + 1];
      if (rowIdx >= mid_offset && rowIdx < next_offset) {
        rowIdx_in_expert = rowIdx - mid_offset;
        expert_idx = mid;
        break;
      } else if (rowIdx < mid_offset) {
        right = mid - 1;
      } else {
        left = mid + 1;
      }
    }
    // Widen before multiplying so the offset cannot wrap in 32-bit arithmetic.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow packed vectors after gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];
    // MXFP4 has no global scale - only block-level E8M0 scale factors
    constexpr float SFScaleVal = 1.0f;
    uint32_t* SFout_in_expert =
        SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
    out_pos =
        cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF, /*UE8M0_SF=*/true>(
            quant_input, SFScaleVal, sf_out);
  }
}
// Host-side launcher for the MXFP4 experts quantization kernels.
//
// Picks a launch configuration from the amount of work
// (m_topk * k / ELTS_PER_THREAD work items) and dispatches to:
//   * the shared-memory / binary-search kernel when each thread has to
//     process more than one work item (blockRepeat > 1), or
//   * the low-latency kernel otherwise,
// selecting the SMALL_NUM_EXPERTS specialization from n_experts.
//
// All pointers are device pointers; the offset buffers hold n_experts + 1
// 32-bit entries. The launch is asynchronous on `stream`.
// Returns without launching when there is no work (m_topk == 0 or k == 0).
template <typename T, bool FUSE_SILU_MUL = false>
void mxfp4_quant_impl(void* output, void* output_scale, void* input,
                      void* input_offset_by_experts,
                      void* output_scale_offset_by_experts, int m_topk, int k,
                      int n_experts, cudaStream_t stream) {
  int const workSizePerRow = k / ELTS_PER_THREAD;
  int const totalWorkSize = m_topk * workSizePerRow;
  // Nothing to quantize. Bail out before the launch-shape math: an empty
  // problem yields grid.x == 0, which would divide by zero when computing
  // blockRepeat and is an invalid launch configuration anyway.
  if (totalWorkSize <= 0) {
    return;
  }
  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
  dim3 block(std::min(workSizePerRow, 512));
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
                     multiProcessorCount * numBlocksPerSM));
  // Trade block size for more blocks while the grid does not yet cover all
  // SMs and blocks are still larger than 64 threads.
  while (grid.x <= multiProcessorCount && block.x > 64) {
    grid.x *= 2;
    block.x = (block.x + 1) / 2;
  }
  // Number of grid-stride iterations each thread will execute.
  int const blockRepeat =
      (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
  if (blockRepeat > 1) {
    // Shared-memory kernel: one uint32 per expert boundary.
    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
    if (n_experts >= 4) {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
          <<<grid, block, shared_mem_size, stream>>>(
              m_topk, k, reinterpret_cast<T*>(input),
              reinterpret_cast<fp4_packed_t*>(output),
              reinterpret_cast<uint32_t*>(output_scale),
              reinterpret_cast<uint32_t*>(input_offset_by_experts),
              reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
              n_experts);
    } else {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true>
          <<<grid, block, shared_mem_size, stream>>>(
              m_topk, k, reinterpret_cast<T*>(input),
              reinterpret_cast<fp4_packed_t*>(output),
              reinterpret_cast<uint32_t*>(output_scale),
              reinterpret_cast<uint32_t*>(input_offset_by_experts),
              reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
              n_experts);
    }
  } else {
    // Low-latency kernel: offsets are read directly from global memory.
    if (n_experts >= 16) {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
          <<<grid, block, 0, stream>>>(
              m_topk, k, reinterpret_cast<T*>(input),
              reinterpret_cast<fp4_packed_t*>(output),
              reinterpret_cast<uint32_t*>(output_scale),
              reinterpret_cast<uint32_t*>(input_offset_by_experts),
              reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
              n_experts, /* bool low_latency */ true);
    } else {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true><<<grid, block, 0, stream>>>(
          m_topk, k, reinterpret_cast<T*>(input),
          reinterpret_cast<fp4_packed_t*>(output),
          reinterpret_cast<uint32_t*>(output_scale),
          reinterpret_cast<uint32_t*>(input_offset_by_experts),
          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
          n_experts, /* bool low_latency */ true);
    }
  }
}
} // namespace vllm
/*Quantization entry for mxfp4 experts quantization*/
// Argument-checking helpers. `m` is the argument name used as the message
// prefix; the message fragments are concatenated verbatim, so each suffix
// starts with a space to keep the text readable ("output must be ...").
#define CHECK_TH_CUDA(x, m) \
  STD_TORCH_CHECK(x.is_cuda(), m, " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  STD_TORCH_CHECK(x.is_contiguous(), m, " must be contiguous")
// Wrapped in do/while so the macro expands to a single statement and is safe
// inside unbraced if/else bodies.
#define CHECK_INPUT(x, m)   \
  do {                      \
    CHECK_TH_CUDA(x, m);    \
    CHECK_CONTIGUOUS(x, m); \
  } while (0)
// Short aliases for the scalar types checked by the validation code below.
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Number of elements sharing one E8M0 scale factor (same value as
// vllm::MXFP4_SF_VEC_SIZE above).
static constexpr int MXFP4_BLOCK_SIZE = 32;
// Validates the tensors handed to the MXFP4 experts quantization entry points.
//
// Checks, in order: device/contiguity of every tensor, ranks, dtypes, that k
// is a multiple of the 32-element block, the lengths of both offset tensors
// (n_experts + 1), the shape of `output` ([m_topk, k / 2]) and the width of
// `output_scale` (ceil(k / 32 / 4) int32 columns). Fails through
// STD_TORCH_CHECK on the first violated condition.
static void validate_mxfp4_experts_quant_inputs(
    torch::stable::Tensor const& output,
    torch::stable::Tensor const& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts, int64_t m_topk, int64_t k) {
  // --- placement and memory layout ---
  CHECK_INPUT(output, "output");
  CHECK_INPUT(output_scale, "output_scale");
  CHECK_INPUT(input, "input");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
  CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");

  // --- ranks: data tensors are matrices, offset tensors are vectors ---
  STD_TORCH_CHECK(output.dim() == 2);
  STD_TORCH_CHECK(output_scale.dim() == 2);
  STD_TORCH_CHECK(input.dim() == 2);
  STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
  STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);

  // --- dtypes ---
  STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
  STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
  STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
  // Packed payload: two mxfp4 values per uint8.
  // Packed scales: four E8M0 values per int32.
  STD_TORCH_CHECK(output.scalar_type() == UINT8);
  STD_TORCH_CHECK(output_scale.scalar_type() == INT);

  // --- sizes ---
  STD_TORCH_CHECK(k % MXFP4_BLOCK_SIZE == 0, "k must be a multiple of 32");
  STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
  STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
  STD_TORCH_CHECK(output.size(0) == m_topk);
  STD_TORCH_CHECK(output.size(1) == k / 2);

  // One scale per 32-element block along K, rounded up to a multiple of 4
  // for the swizzled layout; each int32 column carries 4 of them.
  int num_block_scales = k / MXFP4_BLOCK_SIZE;
  int num_block_scales_padded = (num_block_scales + (4 - 1)) / 4 * 4;
  STD_TORCH_CHECK(output_scale.size(1) * 4 == num_block_scales_padded);
}
// Op entry point: MXFP4 quantization of packed MoE activations.
//
// Takes the problem shape from `input` ([m_topk, k]), validates every tensor,
// then dispatches on the input dtype and launches the quantization on the
// current CUDA stream of the input's device.
void mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  auto num_rows = input.size(0);
  auto num_cols = input.size(1);

  validate_mxfp4_experts_quant_inputs(
      output, output_scale, input, input_offset_by_experts,
      output_scale_offset_by_experts, n_experts, num_rows, num_cols);

  // Make the input's device current for the duration of the launch.
  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "mxfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::mxfp4_quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(), num_rows, num_cols,
            n_experts, stream);
      });
}
// Op entry point: fused SiLU(gate) * up + MXFP4 quantization for MoE.
//
// `input` is [m_topk, 2 * k] in gate || up layout. The quantized width k is
// half the input width; validation and launch then proceed as for the
// non-fused entry point, with the fused kernel specialization.
void silu_and_mul_mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  auto num_rows = input.size(0);
  auto gate_up_width = input.size(1);
  STD_TORCH_CHECK(gate_up_width % 2 == 0,
                  "input width must be even (gate || up)");
  auto num_cols = gate_up_width / 2;

  validate_mxfp4_experts_quant_inputs(
      output, output_scale, input, input_offset_by_experts,
      output_scale_offset_by_experts, n_experts, num_rows, num_cols);

  // Make the input's device current for the duration of the launch.
  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "silu_mul_mxfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::mxfp4_quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(), num_rows, num_cols,
            n_experts, stream);
      });
}
......@@ -116,6 +116,12 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// cutlass mxfp4 block scaled group GEMM (MXFP4 x MXFP4 MoE)
ops.def(
"cutlass_mxfp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
......@@ -149,6 +155,19 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Compute MXFP4 experts quantization (32-element blocks, E8M0 SFs).
ops.def(
"mxfp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()");
// Fused SiLU+Mul+MXFP4 experts quantization.
ops.def(
"silu_and_mul_mxfp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
......@@ -233,6 +252,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
ops.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant));
ops.impl("silu_and_mul_mxfp4_experts_quant",
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for SM100 CUTLASS MXFP4 x MXFP4 grouped MoE kernels."""
import random
import pytest
import torch
from tests.kernels.utils import torch_moe_single
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
# Seed Python's `random` module at import time so the randomly drawn problem
# sizes in the tests below are reproducible.
random.seed(42)
# Project helper, called with the same seed.
set_random_seed(42)
# Number of elements that share one MXFP4 block scale.
MXFP4_BLOCK_SIZE = 32
def align(val: int, alignment: int = 128) -> int:
    """Round ``val`` up to the nearest multiple of ``alignment``."""
    num_blocks = (val + alignment - 1) // alignment
    return int(num_blocks * alignment)
def calc_diff(x, y):
    """Return ``1 - 2*sum(x*y) / sum(x*x + y*y)``, computed in float64.

    The result is 0 for identical tensors and grows as they diverge.
    """
    x64 = x.double()
    y64 = y.double()
    energy = (x64 * x64 + y64 * y64).sum()
    similarity = 2 * (x64 * y64).sum() / energy
    return 1 - similarity
def is_sm100_supported() -> bool:
    """Return True when running on CUDA with a capability-family-100 device."""
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
def compute_ref_output(
    input_tensor: torch.Tensor,
    weight_list: list[torch.Tensor],
    expert_offsets: list[int],
    expert_offset: int,
    num_experts: int,
) -> torch.Tensor:
    """Reference output using torch_moe_single with top-1 routing.

    Builds a score matrix of shape (expert_offset, num_experts) that holds 0.0
    in the column of the expert owning each row and -1e9 elsewhere, so that
    top-1 selection sends every row to its own expert.
    """
    score = torch.full(
        (expert_offset, num_experts),
        -1e9,
        device=input_tensor.device,
        dtype=torch.float32,
    )
    # Row boundaries per expert; the final boundary is the total row count.
    bounds = list(expert_offsets[:num_experts]) + [expert_offset]
    for expert in range(num_experts):
        score[bounds[expert] : bounds[expert + 1], expert] = 0.0

    stacked_weights = torch.stack(weight_list, dim=0)
    return torch_moe_single(input_tensor, stacked_weights, score, topk=1)
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="cutlass_mxfp4_group_mm requires CUDA SM100",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_cutlass_mxfp4_grouped_mm(num_experts, out_dtype):
    """
    End-to-end check of the MXFP4 grouped GEMM:
    1. Draw random per-expert activations and weights
    2. Quantize both to MXFP4 with the CUDA quantization op
    3. Run the CUTLASS grouped GEMM on the quantized operands
    4. Compare each expert's result against a BF16 reference
    """
    device = "cuda"
    alignment = 128

    # N and K are multiples of 128 so the swizzled scale layout has no padding.
    n_g = random.randint(1, 16) * alignment
    k_g = random.randint(1, 16) * alignment

    # Draw per-expert row counts and data. The order of RNG calls (row count,
    # activations, weights) is kept per expert.
    rows_per_expert = []
    input_list = []
    weight_list = []
    for _ in range(num_experts):
        m_g = random.randint(1, 256)
        rows_per_expert.append(m_g)
        input_list.append(
            torch.normal(0.0, std=0.5, size=(m_g, k_g), device=device, dtype=out_dtype)
        )
        weight_list.append(
            torch.normal(0.0, std=0.5, size=(n_g, k_g), device=device, dtype=out_dtype)
        )

    # Exclusive prefix sums: row offsets, and scale-row offsets where every
    # expert's rows are padded up to a multiple of 128.
    expert_offsets_input = []
    input_bs_offsets = []
    expert_offset = 0
    padded_rows = 0
    for m_g in rows_per_expert:
        expert_offsets_input.append(expert_offset)
        input_bs_offsets.append(padded_rows)
        expert_offset += m_g
        padded_rows += align(m_g, 128)
    input_bs_offsets.append(padded_rows)
    problem_sizes = [[m_g, n_g, k_g] for m_g in rows_per_expert]

    input_tensor = torch.concat(input_list, dim=0)  # [M_total, K]

    # --- Quantize INPUTS via mxfp4_experts_quant ---
    _inp_expert_offsets = torch.tensor(
        expert_offsets_input + [expert_offset], device=device, dtype=torch.int32
    )
    _inp_bs_offsets = torch.tensor(input_bs_offsets, device=device, dtype=torch.int32)
    input_quant, input_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _inp_expert_offsets,
        _inp_bs_offsets,
        num_experts,
        topk=1,
    )

    # --- Quantize WEIGHTS via mxfp4_experts_quant ---
    # Each expert's N weight rows are treated as that expert's "tokens".
    weight_tensor = torch.concat(weight_list, dim=0)  # [E*N, K]
    # N is a multiple of 128, so row offsets and scale-row offsets coincide.
    weight_row_offsets = [g * n_g for g in range(num_experts + 1)]
    _wt_expert_offsets = torch.tensor(
        weight_row_offsets, device=device, dtype=torch.int32
    )
    _wt_bs_offsets = torch.tensor(
        list(weight_row_offsets), device=device, dtype=torch.int32
    )
    weight_quant, weight_sf = ops.mxfp4_experts_quant(
        weight_tensor,
        _wt_expert_offsets,
        _wt_bs_offsets,
        num_experts,
        topk=1,
    )

    # Packed weights -> [E, N, K//2]
    total_weight_rows = num_experts * n_g
    weight_quant = weight_quant[:total_weight_rows].view(num_experts, n_g, k_g // 2)

    # Weight scale factors -> [E, N, K//32]; the op returns a uint8 view with
    # K//32 scale factors per row.
    scales_per_row = k_g // MXFP4_BLOCK_SIZE
    weight_sf_3d = weight_sf.view(-1)[: total_weight_rows * scales_per_row].view(
        num_experts, n_g, scales_per_row
    )

    # GEMM output and metadata tensors
    output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype)
    _problem_sizes = torch.tensor(problem_sizes, device=device, dtype=torch.int32)
    _expert_offsets = torch.tensor(
        expert_offsets_input, device=device, dtype=torch.int32
    )
    _input_bs = torch.tensor(input_bs_offsets[:-1], device=device, dtype=torch.int32)

    # Run the MXFP4 grouped GEMM
    ops.cutlass_mxfp4_moe_mm(
        output,
        input_quant,
        weight_quant,
        input_sf,
        weight_sf_3d,
        _problem_sizes,
        _expert_offsets,
        _input_bs,
    )

    # Reference: BF16 matmul
    ref_output = compute_ref_output(
        input_tensor=input_tensor,
        weight_list=weight_list,
        expert_offsets=expert_offsets_input,
        expert_offset=expert_offset,
        num_experts=num_experts,
    )

    # Compare expert by expert
    row_bounds = expert_offsets_input + [expert_offset]
    for g in range(num_experts):
        start, end = row_bounds[g], row_bounds[g + 1]
        if start == end:
            continue
        diff = calc_diff(output[start:end], ref_output[start:end])
        print(
            f"m_g={end - start} n_g={n_g} k_g={k_g} "
            f"num_experts={num_experts}, "
            f"out_dtype={out_dtype}, diff={diff:.5f}"
        )
        # FP4 quantization is very lossy (~4 bits precision); comparing
        # quantized against full precision gives a cosine diff of 0.05-0.15.
        assert diff < 0.15, f"Expert {g}: diff={diff:.5f} exceeds threshold"
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="mxfp4_experts_quant requires CUDA SM100",
)
def test_mxfp4_experts_quant_basic():
    """
    Basic smoke test for the MXFP4 experts quantization kernel.

    Quantizes a small, evenly split batch and checks the shape and dtype of
    the packed output and of the scale factors.
    """
    device = "cuda"
    num_experts = 4
    k = 256
    tokens_per_expert = 16
    total_tokens = tokens_per_expert * num_experts
    input_tensor = torch.randn(total_tokens, k, device=device, dtype=torch.bfloat16) / 5

    expert_offsets = [i * tokens_per_expert for i in range(num_experts + 1)]
    # Each expert gets its own scale-factor region, padded to 128 rows.
    # Pad the per-expert row count and then accumulate: aligning the running
    # token offset instead would give [0, 128, 128, 128, 128] and make
    # experts 1..3 write their scale factors to the same rows.
    padded_rows_per_expert = align(tokens_per_expert, 128)
    blockscale_offsets = [i * padded_rows_per_expert for i in range(num_experts + 1)]

    _expert_offsets = torch.tensor(expert_offsets, device=device, dtype=torch.int32)
    _blockscale_offsets = torch.tensor(
        blockscale_offsets, device=device, dtype=torch.int32
    )
    output, output_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _expert_offsets,
        _blockscale_offsets,
        num_experts,
        topk=1,
    )
    assert output.shape == (total_tokens, k // 2)
    assert output.dtype == torch.uint8
    assert output_sf.dtype == torch.uint8
    assert output.any(), "Quantized output is all zeros"
    print(
        f"MXFP4 experts quant: output shape={output.shape}, sf shape={output_sf.shape}"
    )
    print("PASSED")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest_args = [__file__, "-v", "-s"]
    pytest.main(pytest_args)
......@@ -1150,6 +1150,38 @@ def cutlass_fp4_moe_mm(
)
def cutlass_mxfp4_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    sf_offsets: torch.Tensor,
):
    """
    MXFP4 x MXFP4 block-scaled grouped GEMM for MoE.

    Thin wrapper over the ``cutlass_mxfp4_group_mm`` custom op (mx_float4_t
    operands, E8M0 scale factors, 32-element blocks).

    Args:
        out_tensors: output buffer, written in place.
        a_tensors / b_tensors: packed MXFP4 activations / weights
            (uint8, two E2M1 values per byte).
        a_scales / b_scales: E8M0 block scales (uint8, swizzled layout).
        problem_sizes: (num_experts, 3) with (M, N, K) per expert.
        expert_offsets / sf_offsets: expert boundary indices.

    The epilogue uses scalar alpha=1, beta=0 inside the CUDA op; there are no
    global scales.
    """
    group_mm = torch.ops._C.cutlass_mxfp4_group_mm
    return group_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        problem_sizes,
        expert_offsets,
        sf_offsets,
    )
def mxfp8_experts_quant(
input_tensor: torch.Tensor,
problem_sizes: torch.Tensor,
......@@ -1848,6 +1880,109 @@ def silu_and_mul_scaled_fp4_experts_quant(
return output, output_scales
def mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize packed MoE activations to MXFP4.

    Each group of 32 elements shares one E8M0 (power-of-two) scale factor;
    MXFP4 has no global scale.

    Args:
        input_tensor: [m_topk, k] BF16/FP16 activations
        expert_offsets: [n_experts+1] token boundaries per expert
        blockscale_offsets: [n_experts+1] SF row boundaries per expert
        n_experts: number of experts
        topk: number of top-k experts
    Returns:
        output: [m_topk, k//2] packed E2M1 values (uint8)
        output_scales: E8M0 blockscales in swizzled layout (uint8 view)
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2

    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_mxfp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )

    device = input_tensor.device
    # One E8M0 scale per 32 elements; four scales are packed into each int32,
    # so the scale buffer needs ceil((k // 32) / 4) int32 columns.
    sf_words_per_row = (k // 32 + (4 - 1)) // 4

    output = torch.empty((m_numtopk, k // 2), dtype=torch.uint8, device=device)
    output_scales = torch.empty(
        (MAX_TOKENS_PER_EXPERT * topk, sf_words_per_row),
        dtype=torch.int32,
        device=device,
    )

    torch.ops._C.mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )

    # Expose the packed int32 words as individual uint8 E8M0 scale factors.
    return output, output_scales.view(torch.uint8)
def silu_and_mul_mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused SiLU+Mul+MXFP4 quantization for MoE intermediate activations.
    MXFP4 has no global scale - only block-level E8M0 scale factors.

    Args:
        input_tensor: [m_topk, 2*k] BF16/FP16 activations in gate || up layout
        expert_offsets: [n_experts+1] token boundaries per expert
        blockscale_offsets: [n_experts+1] SF row boundaries per expert
        n_experts: number of experts
        topk: number of top-k experts
    Returns:
        output: [m_topk, k//2] packed E2M1 values (uint8)
        output_scales: E8M0 blockscales in swizzled layout (uint8 view)
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k_times_2 = input_tensor.shape
    assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
    k = k_times_2 // 2
    # Same capacity check, with the same diagnostic, as mxfp4_experts_quant.
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_mxfp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )
    # One E8M0 scale per 32 elements; four scales are packed into each int32.
    scales_k = k // 32
    padded_k = (scales_k + (4 - 1)) // 4
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops._C.silu_and_mul_mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )
    # E8M0 SFs are stored as uint8
    output_scales = output_scales.view(torch.uint8)
    return output, output_scales
# fp8
def scaled_fp8_quant(
input: torch.Tensor,
......
......@@ -762,6 +762,25 @@ def nvfp4_moe_quant_config(
)
def mxfp4_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Build the quant config for MXFP4 x MXFP4 MoE.

    Only the two weight block-scale tensors are supplied: MXFP4 relies on
    block scaling alone (E8M0 scales over 32-element groups), so this config
    carries no alphas or global activation scales.
    """
    weight_scales = {"w1_scale": w1_scale, "w2_scale": w2_scale}
    return FusedMoEQuantConfig.make(
        "mxfp4",
        **weight_scales,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    )
def nvfp4_w4a16_moe_quant_config(
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
......
......@@ -36,6 +36,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Dynamic,
kMxfp4Static,
kNvfp4Dynamic,
kNvfp4Static,
)
......@@ -795,6 +797,299 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
)
def run_cutlass_moe_mxfp4(
    output: torch.Tensor,
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w2_fp4: torch.Tensor,
    w2_blockscale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    apply_router_weight_on_input: bool = False,
) -> None:
    """
    Run an MXFP4 x MXFP4 mixture-of-experts forward pass with CUTLASS grouped
    GEMMs and write the reduced result into ``output``.

    Pipeline, in order:
      1. Optionally scale ``a`` in place by the router weights (topk == 1 only).
      2. Compute per-expert offsets / problem sizes and the row permutations.
      3. Permute the activations, quantize them, and run the first grouped GEMM.
      4. Apply the activation and re-quantize (fused op for SILU, otherwise a
         separate activation followed by quantization).
      5. Run the second grouped GEMM, undo the permutation, and sum over the
         top-k axis (weighted unless the weights were applied on the input).

    Args:
        output: destination tensor; must have the same dtype as ``a``.
        a: 2-D activations of dtype half or bfloat16. Mutated in place when
            ``apply_router_weight_on_input`` is True.
        w1_fp4, w2_fp4: rank-3 uint8 packed expert weights.
        w1_blockscale, w2_blockscale: rank-3 block scales for the weights.
        topk_weights, topk_ids: routing weights and expert ids, same shape,
            leading dimension ``m``.
        activation: activation selector; ``is_gated`` decides whether the
            first GEMM produces ``2 * n`` or ``n`` columns.
        workspace13: scratch buffer backing both the first and second GEMM
            outputs.
        workspace2: scratch buffer backing the non-fused activation output.
        m, n, k, e: problem dimensions, checked against the tensor shapes.
        device: device on which the bookkeeping tensors are allocated.
        apply_router_weight_on_input: multiply the router weights into the
            input instead of the output.
    """
    gated = activation.is_gated
    gemm1_n = 2 * n if gated else n

    # Validation is kept in its original order so the first failing check
    # (and its message) is unchanged.
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert all(
        t.ndim == 3 for t in (w1_fp4, w2_fp4, w1_blockscale, w2_blockscale)
    ), "All Weights must be of rank 3 for cutlass_moe_mxfp4"

    rows_a, cols_a = a.shape
    experts_w1, w1_rows, w1_packed_k = w1_fp4.shape
    experts_w2, w2_rows, w2_packed_n = w2_fp4.shape
    assert experts_w1 == experts_w2 and experts_w1 == e
    assert cols_a == w1_packed_k * 2 and k == w2_rows
    assert w1_rows == gemm1_n and w2_packed_n * 2 == n
    assert m == rows_a
    assert 2 * w1_packed_k == w2_rows
    assert a.dtype in (torch.half, torch.bfloat16), "Invalid input dtype"
    assert topk_weights.size(0) == m and topk_ids.size(0) == m

    top_k = topk_ids.size(1)
    out_dtype = a.dtype

    def _empty_i32(*shape: int) -> torch.Tensor:
        # Uninitialised int32 bookkeeping buffer on the target device.
        return torch.empty(shape, dtype=torch.int32, device=device)

    expert_offsets = _empty_i32(e + 1)
    blockscale_offsets = _empty_i32(e + 1)
    problem_sizes1 = _empty_i32(e, 3)
    problem_sizes2 = _empty_i32(e, 3)
    a_map = _empty_i32(topk_ids.numel())
    c_map = _empty_i32(topk_ids.numel())

    if apply_router_weight_on_input:
        assert top_k == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        # In-place: the caller's activation tensor is scaled.
        a.mul_(topk_weights.to(out_dtype))

    ops.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        e,
        n,
        k,
        blockscale_offsets,
        is_gated=gated,
    )

    a = ops.shuffle_rows(a, a_map)
    rep_a_fp4, rep_a_blockscale = ops.mxfp4_experts_quant(
        a,
        expert_offsets,
        blockscale_offsets,
        e,
        top_k,
    )

    # c1 and c3 are both carved out of workspace13; c2 lives in workspace2.
    routed_rows = m * top_k
    c1 = _resize_cache(workspace13, (routed_rows, gemm1_n))
    c2 = _resize_cache(workspace2, (routed_rows, n))
    c3 = _resize_cache(workspace13, (routed_rows, k))

    # First grouped GEMM: quantized activations x w1.
    ops.cutlass_mxfp4_moe_mm(
        c1,
        rep_a_fp4,
        w1_fp4,
        rep_a_blockscale,
        w1_blockscale,
        problem_sizes1,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del rep_a_fp4, rep_a_blockscale

    if activation == MoEActivation.SILU:
        # Fused SiLU+Mul+quantization reads c1 directly.
        int_fp4, int_blockscale = ops.silu_and_mul_mxfp4_experts_quant(
            c1, expert_offsets, blockscale_offsets, e, top_k
        )
    else:
        # Generic path: activation into c2, then quantize c2.
        apply_moe_activation(activation, c2, c1)
        int_fp4, int_blockscale = ops.mxfp4_experts_quant(
            c2, expert_offsets, blockscale_offsets, e, top_k
        )

    # Second grouped GEMM: quantized intermediate x w2.
    ops.cutlass_mxfp4_moe_mm(
        c3,
        int_fp4,
        w2_fp4,
        int_blockscale,
        w2_blockscale,
        problem_sizes2,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del int_fp4, int_blockscale

    c3 = ops.shuffle_rows(c3, c_map)
    assert output.dtype == out_dtype

    per_expert_out = c3.view(m, top_k, k)
    if not apply_router_weight_on_input:
        # Router weights were not folded into the input, so apply them here.
        per_expert_out = per_expert_out * topk_weights.view(m, top_k, 1).to(
            out_dtype
        )
    output.copy_(per_expert_out.sum(dim=1), non_blocking=True)
def swizzle_mxfp4_scales(
    scales: torch.Tensor,
    N: int,
    K: int,
) -> torch.Tensor:
    """Rearrange flat [N, K//32] uint8 scales into the CUTLASS tiled order.

    The result is a 1-D tensor laid out as
        [numMTiles, numKTiles, 32, 4, 4]
    with numMTiles = ceil(N / 128) and numKTiles = ceil((K // 32) / 4).
    For a source element at row ``mIdx`` and scale column ``kIdx``:
        mTileIdx  = mIdx // 128
        outerMIdx = mIdx % 32
        innerMIdx = (mIdx // 32) % 4
        kTileIdx  = kIdx // 4
        innerKIdx = kIdx % 4
    and its position in the output is
        [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx].

    Rows beyond N and columns beyond K // 32 are zero-filled so that every
    tile is complete; the returned tensor therefore has
    ``numMTiles * numKTiles * 512`` elements.

    Args:
        scales: uint8 tensor assigned into the top-left [N, K // 32] region.
        N: number of rows covered by ``scales``.
        K: unquantized inner dimension (32 elements per scale value).

    Returns:
        Flattened, contiguous uint8 tensor on the same device as ``scales``.
    """
    assert scales.dtype == torch.uint8

    scale_cols = K // 32  # one scale value per 32-element group
    m_tiles = -(-N // 128)  # ceil division
    k_tiles = -(-scale_cols // 4)

    # Zero canvas sized to whole tiles; real scales go in the top-left corner.
    canvas = torch.zeros(
        (m_tiles * 128, k_tiles * 4), dtype=torch.uint8, device=scales.device
    )
    canvas[:N, :scale_cols] = scales

    # Split rows as (mTile, innerM, outerM) and columns as (kTile, innerK),
    # then reorder to (mTile, kTile, outerM, innerM, innerK).
    swizzled = (
        canvas.reshape(m_tiles, 4, 32, k_tiles, 4)
        .permute(0, 3, 2, 1, 4)
        .contiguous()
    )
    return swizzled.reshape(-1)
class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular):
    """Fused MoE experts backed by the CUTLASS MXFP4 x MXFP4 grouped GEMM."""

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Always True for this implementation."""
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        """True on CUDA platforms in device capability family 100."""
        return current_platform.is_cuda() and (
            current_platform.is_device_capability_family(100)
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Non-gated (no act-and-mul) activations are accepted."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Only static MXFP4 weights with dynamic MXFP4 activations."""
        supported_scheme = (kMxfp4Static, kMxfp4Dynamic)
        return (weight_key, activation_key) == supported_scheme

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Whether ``activation`` is one of the accepted activations."""
        accepted = [
            MoEActivation.SILU,
            MoEActivation.GELU,
            MoEActivation.SWIGLUOAI,
            MoEActivation.SWIGLUSTEP,
            MoEActivation.SILU_NO_MUL,
            MoEActivation.GELU_NO_MUL,
            MoEActivation.RELU2_NO_MUL,
        ]
        return activation in accepted

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Only configurations with ``ep_size == 1`` are accepted."""
        return moe_parallel_config.ep_size == 1

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """This implementation uses the Standard activation format."""
        return mk.FusedMoEActivationFormat.Standard

    def supports_expert_map(self) -> bool:
        """Expert maps are not supported."""
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op weight-and-reduce object."""
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """Workspaces use the activation dtype unchanged."""
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return (workspace1, workspace2, output) shapes for the problem.

        workspace1 is wide enough for either ``2 * N`` or ``K`` columns;
        workspace2 holds ``N`` columns. Both have ``M * topk`` rows.
        """
        routed_rows = M * topk
        return (
            (routed_rows, max(2 * N, K)),
            (routed_rows, N),
            (M, K),
        )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the MXFP4 MoE forward pass, writing the result to ``output``.

        The intermediate size is taken from ``w2`` (last dimension times two)
        rather than from ``moe_problem_size``.
        """
        num_experts, num_tokens, _, hidden_k, _ = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )
        intermediate_n = w2.shape[2] * 2
        run_cutlass_moe_mxfp4(
            output=output,
            a=hidden_states,
            w1_fp4=w1,
            w1_blockscale=self.w1_scale,
            w2_fp4=w2,
            w2_blockscale=self.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            workspace13=workspace13,
            workspace2=workspace2,
            m=num_tokens,
            n=intermediate_n,
            k=hidden_k,
            e=num_experts,
            device=hidden_states.device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
......@@ -11,6 +12,10 @@ from vllm.model_executor.layers.fused_moe import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsMxfp4,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
......@@ -36,7 +41,14 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
self.use_cutlass_mxfp4 = CutlassExpertsMxfp4._supports_current_device()
self.experts_cls: type[mk.FusedMoEExperts]
if self.use_cutlass_mxfp4:
logger.info_once("Using CutlassExpertsMxfp4 for MXFP4 MoE", scope="local")
self.experts_cls = CutlassExpertsMxfp4
else:
logger.info_once("Using MarlinExperts for MXFP4 MoE", scope="local")
self.experts_cls = MarlinExperts
def create_weights(
self,
......@@ -109,11 +121,19 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the fused-MoE quant config for the backend chosen in __init__.

    A stale unconditional ``return make_mxfp4_moe_quant_config(...)`` used to
    sit ahead of the backend check, which made the CUTLASS branch
    unreachable. It is removed so the check below actually selects the
    config.

    Args:
        layer: module providing ``w13_weight_scale`` and ``w2_weight_scale``.

    Returns:
        The W4A4 MXFP4 config when ``self.use_cutlass_mxfp4`` is set,
        otherwise the config built for ``self.mxfp4_backend``.
    """
    if self.use_cutlass_mxfp4:
        # W4A4: both weights and activations quantized to MXFP4
        return mxfp4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
    # W4A16: weight-only via Marlin
    return make_mxfp4_moe_quant_config(
        mxfp4_backend=self.mxfp4_backend,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
    )
def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter(
......@@ -126,13 +146,45 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
)
delattr(layer, "w2_weight_packed")
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression "
"will be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin(layer)
if self.use_cutlass_mxfp4:
# Swizzle weight scales from flat checkpoint layout [E, N, K//32]
# to CUTLASS tiled layout [E, numMTiles*numKTiles*512].
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
swizzle_mxfp4_scales,
)
E = layer.w13_weight_scale.shape[0]
w13_N = layer.w13_weight_scale.shape[1]
w13_scale_K = layer.w13_weight_scale.shape[2]
w13_K = w13_scale_K * 32
w2_M = layer.w2_weight_scale.shape[1]
w2_scale_N = layer.w2_weight_scale.shape[2]
w2_N = w2_scale_N * 32
swizzled_w13 = []
swizzled_w2 = []
for e_idx in range(E):
s13 = layer.w13_weight_scale[e_idx]
sw13 = swizzle_mxfp4_scales(s13, w13_N, w13_K)
swizzled_w13.append(sw13.reshape(w13_N, w13_scale_K))
s2 = layer.w2_weight_scale[e_idx]
sw2 = swizzle_mxfp4_scales(s2, w2_M, w2_N)
swizzled_w2.append(sw2.reshape(w2_M, w2_scale_N))
layer.w13_weight_scale = torch.nn.Parameter(
torch.stack(swizzled_w13), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
torch.stack(swizzled_w2), requires_grad=False
)
else:
logger.warning_once(
"Your GPU does not have native support for FP4 computation "
"but FP4 quantization is being used. Weight-only FP4 "
"compression will be used leveraging the Marlin kernel. "
"This may degrade performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin(layer)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment