Unverified Commit 074854b2 authored by Duncan Moss's avatar Duncan Moss Committed by GitHub
Browse files

[Kernel][B200] `mxfp4` fused cutlass moe (#23696)


Signed-off-by: default avatarDuncan Moss <djm.moss@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 79ac59f3
......@@ -11,6 +11,7 @@ import torch
from packaging import version
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
......@@ -19,6 +20,10 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer())
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
......@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
transpose_optimized=transpose_optimized)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales on the last dimension by groups of 4, matching
the transformation in mxfp4.py's BF16 (Hopper) path."""
s = scales.to(torch.uint8)
s_shape = s.shape
assert s_shape[-1] % 4 == 0
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
# Move the 4-group dimension before the row dimension
permuted = s.permute(0, 2, 1, 3)
# Merge the row dim with the 4-group dim
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not HOPPER_MXFP4_BF16_AVAILABLE,
reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
):
torch.manual_seed(42)
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
w13_q = torch.randint(
0,
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
device=device,
dtype=torch.uint8)
w13_scale = torch.randint(
118,
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
device=device,
dtype=torch.uint8)
w2_q = torch.randint(0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8)
w2_scale = torch.randint(
118,
123, (num_experts, hidden_size, intermediate_size // 32),
device=device,
dtype=torch.uint8)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
num_experts, 2 * intermediate_size, hidden_size)
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
num_experts, hidden_size, intermediate_size)
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
w13_s = torch.cat([w3_s, w1_s], dim=1)
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
fc1_expert_weights=w13_q_swapped,
fc2_expert_weights=w2_q,
output_dtype=torch.bfloat16,
output=out,
quant_scales=[w13_s_inter.to(torch.uint8),
w2_s_inter.to(torch.uint8)],
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha,
swiglu_beta=beta,
swiglu_limit=limit,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
use_w4_group_scaling=True,
)
# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not (current_platform.is_cuda()
and current_platform.is_device_capability(100) and has_flashinfer()),
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: Optional[float],
beta: Optional[float],
limit: Optional[float],
):
torch.manual_seed(42)
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
# Float weights in w13 format [w1; w3]
w13 = (torch.randn(num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16) / 10)
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16) / 10)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
# Quantize weights to MXFP4 per expert (SM100 path)
from flashinfer import mxfp4_quantize
def quant_mxfp4_batches(a: torch.Tensor, e: int):
qs, sfs = [], []
for i in range(e):
q, sf = mxfp4_quantize(a[i].cuda())
qs.append(q)
sfs.append(sf)
return torch.stack(qs), torch.stack(sfs)
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor):
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
from flashinfer import mxfp4_dequantize
return torch.stack([
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
])
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
# Reference result using dequantized tensors and reference_moe
w13_ref = dequant_mxfp4_batches(
w13_q.view(torch.uint8),
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, 2 * intermediate_size, hidden_size)
w2_ref = dequant_mxfp4_batches(
w2_q.view(torch.uint8),
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, hidden_size, intermediate_size)
# Quantize activations for SM100 path and dequantize for reference
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
# Prepare inputs for FlashInfer CUTLASS fused MoE
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
# Swap scales halves to match swapped weights
s1, s3 = torch.chunk(w13_scale, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
# Build routing for kernel
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha_t = torch.full((num_experts, ),
alpha,
device=hidden_states.device)
else:
alpha_t = None
if beta is not None:
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
else:
beta_t = None
if limit is not None:
limit_t = torch.full((num_experts, ),
limit,
device=hidden_states.device)
else:
limit_t = None
# Quant scales for SM100 MXFP8+MXFP4 path
fake_input_scale = torch.ones(num_experts, device=device)
quant_scales = [
w13_scale_swapped.view(torch.int32),
fake_input_scale,
w2_scale.view(torch.int32),
fake_input_scale,
]
_ = flashinfer_cutlass_fused_moe(
input=hidden_states_q,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
fc2_expert_weights=w2_q.contiguous().view(torch.long),
output_dtype=torch.bfloat16,
output=out,
quant_scales=quant_scales,
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha_t,
swiglu_beta=beta_t,
swiglu_limit=limit_t,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
use_mxfp8_act_scaling=True,
input_sf=hidden_states_sf,
)
# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
......@@ -166,7 +166,8 @@ if TYPE_CHECKING:
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
......@@ -1004,6 +1005,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
# If set to 1, use the FlashInfer CUTLASS backend for
# MXFP8 (activation) x MXFP4 (weight) MoE.
# This is separate from the TRTLLMGEN path controlled by
# VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
lambda: bool(int(
os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
)),
# If set to 1, use the FlashInfer
# BF16 (activation) x MXFP4 (weight) MoE backend.
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
......@@ -1296,6 +1306,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
......
......@@ -813,9 +813,16 @@ class FusedMoE(CustomOp):
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501
should_use_flashinfer_mxfp4)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op
......
......@@ -33,8 +33,8 @@ def kernel_warmup(worker: "Worker"):
max_tokens = worker.scheduler_config.max_num_batched_tokens
deep_gemm_warmup(model, max_tokens)
# FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.is_device_capability(100):
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.has_device_capability(90):
flashinfer_autotune(worker.model_runner)
# FlashInfer attention warmup
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment