Unverified Commit 3df05f4d authored by Shu Wang, committed by GitHub

[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)

parent 7b141f81
......@@ -40,6 +40,11 @@ SGLang supports various environment variables that can be used to configure its
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
## DeepEP Configuration
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |
## Memory Management
| Environment Variable | Description | Default Value |
......
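For context, a minimal Python sketch of enabling the flag documented above before server start; only the variable name and value come from this table (and the CuteDSL CI test added later in this commit), the rest is illustrative.

```python
import os

# Opt in to bf16 (instead of fp8) dispatch for the DeepEP low-latency path.
# Default is "false", i.e. fp8 dispatch.
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
```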
......@@ -459,6 +459,8 @@ class DeepEPMoE(EPMoE):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
......@@ -638,6 +640,22 @@ class DeepEPMoE(EPMoE):
return gather_out
def forward_flashinfer_cutedsl(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, _, masked_m, _ = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
output = self.quant_method.apply_without_routing_weights(
layer=self,
x=hidden_states,
masked_m=masked_m,
moe_runner_config=self.moe_runner_config,
)
return output
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
......
from typing import Any, Dict, Optional
import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from sgl_kernel.gemm import (
scaled_fp4_grouped_quant,
silu_and_mul_scaled_fp4_grouped_quant,
)
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor,
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alpha,
w2: torch.Tensor,
a2_global_scale: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states (torch.Tensor): [num_experts, m, k], bf16
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): (l,), number of valid rows per expert
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert (
input_global_scale.dtype == torch.float32
), f"input_global_scale must be float32, got {input_global_scale.dtype}"
assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
assert (
w1_blockscale.dtype == torch.float8_e4m3fn
), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
assert (
w1_alpha.dtype == torch.float32
), f"w1_alpha must be float32, got {w1_alpha.dtype}"
assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
assert (
a2_global_scale.dtype == torch.float32
), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
assert (
w2_blockscale.dtype == torch.float8_e4m3fn
), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
assert (
w2_alpha.dtype == torch.float32
), f"w2_alpha must be float32, got {w2_alpha.dtype}"
# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
num_experts, m, k = hidden_states.shape
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert (
w1.shape[-1] * 2 == k
), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
assert w2.shape[-2:] == (
k,
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
assert input_global_scale.shape == (
num_experts,
), f"input_global_scale must be (l,), got {input_global_scale.shape}"
assert w1_alpha.shape == (
num_experts,
), f"w1_alpha must be (l,), got {w1_alpha.shape}"
assert a2_global_scale.shape == (
num_experts,
), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
assert w2_alpha.shape == (
num_experts,
), f"w2_alpha must be (l,), got {w2_alpha.shape}"
aq, aq_sf = scaled_fp4_grouped_quant(
hidden_states,
input_global_scale,
masked_m,
)
gateup_output = torch.empty(
(num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
)
gateup_output = gateup_output.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
assert aq.dtype == torch.uint8
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = get_cute_dtype(hidden_states)
# Gemm1
grouped_gemm_nt_masked(
(aq, aq_sf),
(w1.permute(1, 2, 0), w1_blockscale),
gateup_output,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w1_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w1_alpha),
) # in logical [m, n, l]
# SILU and quantization
diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
gateup_output.permute(2, 0, 1),
a2_global_scale,
masked_m,
)
# Gemm2
out = torch.empty_like(hidden_states)
out = out.permute(1, 2, 0) # requirement of kernel
grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
return out.permute(2, 0, 1)
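For reference, a condensed usage sketch of `flashinfer_cutedsl_moe_masked`, adapted from the unit test added later in this commit. The shapes, the `scaled_fp4_grouped_quant` weight preparation, and the alpha computation mirror that test; the constants `FLOAT8_E4M3_MAX`/`FLOAT4_E2M1_MAX` are assumed to be 448.0 and 6.0, and running it requires a Blackwell-class GPU with sgl-kernel and the FlashInfer CuteDSL kernels available.

```python
import torch
from sgl_kernel.gemm import scaled_fp4_grouped_quant

# Illustrative sizes only; see test_flashinfer_cutedsl_moe_masked below for the full setup.
num_experts, m, hidden_dim, inter_dim = 8, 16, 128, 256
FLOAT8_E4M3_MAX, FLOAT4_E2M1_MAX = 448.0, 6.0  # assumed format maxima
device = "cuda"

hidden_states = torch.randn(num_experts, m, hidden_dim, dtype=torch.bfloat16, device=device)
masked_m = torch.randint(1, m + 1, (num_experts,), dtype=torch.int32, device=device)

w1 = torch.randn(num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device) / 10
w2 = torch.randn(num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device) / 10

# Per-expert global scales, as in the test (intermediate activation scale assumed 1.0).
input_global_scale = torch.ones(num_experts, dtype=torch.float32, device=device)
a2_global_scale = torch.ones(num_experts, dtype=torch.float32, device=device)
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1.abs().amax(dim=(1, 2)).float()
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2.abs().amax(dim=(1, 2)).float()

# Group-quantize the weights to packed fp4 with e4m3 blockscales.
w1_rows = torch.full((num_experts,), 2 * inter_dim, dtype=torch.int32, device=device)
w2_rows = torch.full((num_experts,), hidden_dim, dtype=torch.int32, device=device)
w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(w1, w1_global_scale, w1_rows)
w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(w2, w2_global_scale, w2_rows)

out = flashinfer_cutedsl_moe_masked(
    hidden_states=hidden_states,
    input_global_scale=input_global_scale,
    w1=w1_fp4.permute(2, 0, 1),
    w1_blockscale=w1_blockscale,
    w1_alpha=1.0 / (input_global_scale * w1_global_scale),
    w2=w2_fp4.permute(2, 0, 1),
    a2_global_scale=a2_global_scale,
    w2_blockscale=w2_blockscale,
    w2_alpha=1.0 / (a2_global_scale * w2_global_scale),
    masked_m=masked_m,
)
# out: [num_experts, m, hidden_dim] bf16; only the first masked_m[i] rows of expert i are meaningful.
```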
......@@ -508,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
hidden_states, masked_m, event, hook = self._dispatch_core(
hidden_states,
topk_idx,
use_fp8=True,
# TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
)
return (
hidden_states,
......
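The dispatcher hunk above stops hard-coding fp8 and instead derives the dispatch precision from the environment. A minimal sketch of that gating, assuming `get_bool_env_var` lives in `sglang.srt.utils`:

```python
from sglang.srt.utils import get_bool_env_var

# Unset or "false" (the default) keeps fp8 dispatch; setting
# SGLANG_DEEPEP_BF16_DISPATCH=1 dispatches activations in bf16 until the
# referenced upstream DeepEP PR lands.
use_fp8 = not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH")
```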
......@@ -49,6 +49,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
def is_auto(self):
return self == MoeRunnerBackend.AUTO
......@@ -65,6 +66,9 @@ class MoeRunnerBackend(Enum):
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
def is_flashinfer_cutedsl(self):
return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
......
......@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
"""Access the global enable_flashinfer_cutlass_moe setting."""
return get_moe_runner_backend().is_flashinfer_cutlass()
@property
def enable_flashinfer_cutedsl_moe(self) -> bool:
"""Access the global enable_flashinfer_cutedsl_moe setting."""
from sglang.srt.layers.moe import get_moe_runner_backend
return get_moe_runner_backend().is_flashinfer_cutedsl()
def create_weights(
self,
layer: torch.nn.Module,
......@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
# Scale by routed_scaling_factor is fused into select_experts.
return StandardCombineInput(hidden_states=output)
def apply_without_routing_weights(
self,
layer: FusedMoE,
x: torch.Tensor,
masked_m: torch.Tensor,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
assert self.enable_flashinfer_cutedsl_moe, "FlashInfer CuteDSL MoE backend must be enabled"
assert (
not moe_runner_config.apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
flashinfer_cutedsl_moe_masked,
)
out = flashinfer_cutedsl_moe_masked(
hidden_states=x,
input_global_scale=layer.w13_input_scale_quant,
w1=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alpha=layer.g1_alphas,
w2=layer.w2_weight,
a2_global_scale=layer.w2_input_scale_quant,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=masked_m,
)
return out
......@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):
if shared_output is not None:
x = shared_output
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
if self.experts.should_fuse_routed_scaling_factor_in_topk():
x.add_(final_hidden_states)
else:
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
final_hidden_states = x
else:
final_hidden_states *= self.routed_scaling_factor
if not self.experts.should_fuse_routed_scaling_factor_in_topk():
final_hidden_states *= self.routed_scaling_factor
return final_hidden_states
......
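The DeepseekV2MoE hunk above guards against applying `routed_scaling_factor` twice once it is fused into the top-k weights. A small numeric sketch of the equivalence (tensor names are illustrative, not taken from the model code):

```python
import torch

routed_scaling_factor = 2.5
shared_output = torch.randn(4, 8)
expert_output = torch.randn(4, 8)  # combined expert output without the factor

# Unfused path: scale when combining with the shared expert output.
unfused = shared_output + expert_output * routed_scaling_factor

# Fused path: the factor was already folded into the top-k routing weights,
# so the expert output arrives pre-scaled and is added as-is.
prescaled_expert_output = expert_output * routed_scaling_factor
fused = shared_output + prescaled_expert_output

torch.testing.assert_close(unfused, fused)
```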
......@@ -399,6 +399,7 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_cutedsl_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
......@@ -420,6 +421,11 @@ class ServerArgs:
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutedsl_moe:
self.moe_runner_backend = "flashinfer_cutedsl"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
......@@ -1622,6 +1628,7 @@ class ServerArgs:
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
......@@ -2204,6 +2211,11 @@ class ServerArgs:
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-cutedsl-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
......
......@@ -3,12 +3,15 @@ from typing import Callable
import pytest
import torch
from flashinfer import fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
from sgl_kernel import scaled_fp4_quant
from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
from torch.nn import functional as F
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
if torch.cuda.get_device_capability() < (10, 0):
......@@ -78,6 +81,37 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype)
def compute_routing(router_logits: torch.Tensor, top_k: int):
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.float()
return routing_weights, selected_experts
def prepare_inputs(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
num_experts: int,
topk: int,
):
routing_weights, topk_idx = compute_routing(router_logits, topk)
masked_m = []
for i in range(num_experts):
mask = topk_idx.view(-1) == i
masked_m.append(mask.sum())
masked_m = torch.tensor(masked_m, dtype=torch.int32)
hidden_states_3d = torch.empty(
(num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
)
for i in range(num_experts):
hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
return hidden_states_3d, masked_m, topk_idx, routing_weights
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
......@@ -114,6 +148,99 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
).sum(dim=1)
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
m = w1[i].shape[0]
assert m % 2 == 0
# Note: w1 and w3 are swapped!
w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
inter_gs = torch.tensor(1.0).cuda()
inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
inter = dequantize_nvfp4_to_dtype(
inter_q,
inter_blockscale,
inter_gs,
dtype=inter.dtype,
device=inter.device,
block_size=16,
).cuda()
out[mask] = inter @ w2[i].transpose(0, 1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states: torch.Tensor, # 3d
input_global_scale: torch.Tensor, # (l,)
weights: torch.Tensor,
w_global_scale: torch.Tensor, # (l,)
masked_m: torch.Tensor,
):
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
# hidden_states: [l, m, k]
# weights: [l, n, k]
aq, aq_sf = scaled_fp4_grouped_quant(
hidden_states,
input_global_scale,
masked_m.to(hidden_states.device),
)
num_experts, n, k = weights.shape
bq, bq_sf = scaled_fp4_grouped_quant(
weights,
w_global_scale,
torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
)
out = torch.zeros(
(num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
)
out = out.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = "bfloat16"
alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
1, 1, num_experts
)
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
grouped_gemm_nt_masked(
(aq, aq_sf),
(bq, bq_sf),
out,
masked_m.to(aq.device),
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=alpha,
alpha_dtype=get_cute_dtype(alpha),
)
return out
def check_moe(
m: int,
n: int,
......@@ -324,6 +451,248 @@ def test_flashinfer_fp4_moe_no_graph(
check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
):
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
num_experts = 8
hidden_states = (
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
)
w1 = (
torch.randn(
num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
w2 = (
torch.randn(
num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(bs, -1, hidden_dim)
.repeat(1, topk, 1)
.reshape(-1, hidden_dim)
)
hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
input_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
)
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
a2_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
) # assume intermediate scale is 1.0
w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
w1,
w1_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
)
w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
w2,
w2_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
)
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
out = flashinfer_cutedsl_moe_masked(
hidden_states_3d.to(hidden_states.device),
input_global_scale,
w1_fp4.permute(2, 0, 1),
w1_blockscale,
w1_alpha,
w2_fp4.permute(2, 0, 1),
a2_global_scale,
w2_blockscale,
w2_alpha,
masked_m.to(hidden_states.device),
)
# reference
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
w1_d = torch.empty(
(num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
)
w2_d = torch.empty(
(num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
)
for idx in range(0, num_experts):
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
w1[idx], w1_global_scale[idx]
)
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
w2[idx], w2_global_scale[idx]
)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_fp4_sliced,
w1_blockscale_sliced,
w1_global_scale[idx],
dtype=w1.dtype,
device=w1.device,
block_size=16,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_fp4_sliced,
w2_blockscale_sliced,
w2_global_scale[idx],
dtype=w2.dtype,
device=w2.device,
block_size=16,
)
ref_output = torch_moe_nvfp4(
a_in_dtype,
w1_d,
w2_d,
topk,
routing_weights.to(a_in_dtype.device),
topk_idx.to(a_in_dtype.device),
)
out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
rows, cols = positions[:, 0], positions[:, 1]
experts = topk_idx[rows, cols]
for i in range(num_experts):
mask = experts == i
if mask.any():
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
r, c = rows[idx], cols[idx]
out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
out.device
).unsqueeze(-1)
torch.testing.assert_close(
out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
)
@pytest.mark.parametrize(
"bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
torch.manual_seed(42)
B = bs
D = hidden_dim
N = inter_dim
num_experts = 8
hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(B, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
)
hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
# reference
out = torch.zeros(
(B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
)
for i in range(num_experts):
mask = topk_idx.view(-1) == i
if mask.sum():
lhs = hidden_states_expanded[mask]
rhs = weights[i]
a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
lhsq, lhsq_sf = fp4_quantize(
lhs,
a_gs,
)
rhsq, rhsq_sf = fp4_quantize(
rhs,
b_gs,
)
lhs_in_dtype = dequantize_nvfp4_to_dtype(
lhsq,
lhsq_sf,
a_gs,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
rhs_in_dtype = dequantize_nvfp4_to_dtype(
rhsq,
rhsq_sf,
b_gs,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
a_amax = (
hidden_states_3d.abs()
.amax(dim=(1, 2))
.to(torch.float32)
.to(hidden_states.device)
)
b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
)
# re-pack out into [num_experts, max_m, n]
out_ref = torch.zeros(
(num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
)
expert_slot = [0] * num_experts
for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
out_ref[expert_id, expert_slot[expert_id], :] = out[i]
expert_slot[expert_id] += 1
# Note: compare only the masked positions, since the CuteDSL kernel may write
# NaN into the unmasked positions.
for i in range(num_experts):
torch.testing.assert_close(
out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
atol=1e-1,
rtol=5e-2,
)
if __name__ == "__main__":
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
test_grouped_gemm_nt_masked(16, 128, 512, 4)
......@@ -53,6 +53,9 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instru
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"
# NVFP4 models
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"
# FP8 models
DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
try_cached_model,
)
class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = try_cached_model(DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST)
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--disable-radix-cache",
"--max-running-requests",
"256",
"--chunked-prefill-size",
"2048",
"--tp",
"8",
"--dp",
"8",
"--enable-dp-attention",
"--enable-ep-moe",
"--quantization",
"modelopt_fp4",
"--enable-flashinfer-cutedsl-moe",
"--enable-deepep-moe",
"--deepep-mode",
"low_latency",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
env={
**os.environ,
"SGLANG_DEEPEP_BF16_DISPATCH": "1",
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
},
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=512,
parallel=512,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Eval accuracy of GSM8K: {metrics=}")
self.assertGreater(metrics["accuracy"], 0.92)
if __name__ == "__main__":
unittest.main()