Unverified commit 14127804, authored by Kaixi Hou and committed by GitHub

[NVIDIA] Fix unit test of MoE and add it to nightly ci (#12709)

parent 82f39dc1
@@ -215,6 +215,9 @@ suites = {
TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
TestFile("test_deepseek_v3_deterministic.py", 240),
],
"nightly-4-gpu-b200": [
TestFile("test_fp4_moe.py", 300),
],
"nightly-8-gpu": [],
"__not_in_ci__": [
TestFile("ascend/test_ascend_w8a8_quantization.py"),
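For context on the hunk above: each suite entry pairs a test script with a number that, judging by the surrounding entries, is an estimated runtime budget in seconds. A minimal sketch of that registration pattern under that assumption (the TestFile definition and the estimated_time field name shown here are illustrative, not taken from this diff):

from dataclasses import dataclass

@dataclass
class TestFile:
    name: str                   # test script path, relative to the test directory
    estimated_time: float = 60  # assumed: rough runtime budget in seconds

# The new nightly suite points a 4-GPU B200 job at the FP4 MoE test,
# with a 300-second budget, mirroring the entries already in `suites`.
suites = {
    "nightly-4-gpu-b200": [
        TestFile("test_fp4_moe.py", 300),
    ],
}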
@@ -5,10 +5,9 @@ import pytest
import torch
from flashinfer import fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant, silu_and_mul
from torch.nn import functional as F
-from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
@@ -140,7 +139,7 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
-out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
+out[mask] = silu_and_mul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
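The two hunks above replace the SiluAndMul module from sglang.srt.layers.activation with the sgl_kernel silu_and_mul op in the torch_moe reference path. Both are expected to compute the same gated activation; a minimal pure-PyTorch sketch, assuming the usual convention that the first half of the last dimension is the SiLU-activated gate (an illustrative reference, not the kernel implementation):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]: apply SiLU to the first half and use it to gate the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# In torch_moe, the input is a[mask] @ w1[i].T with 2 * inter_dim columns,
# so the gated result has inter_dim columns and feeds the w2 projection.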
@@ -451,248 +450,6 @@ def test_flashinfer_fp4_moe_no_graph(
check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
):
torch.manual_seed(42)
device = "cuda"
dtype = torch.bfloat16
num_experts = 8
hidden_states = (
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
)
w1 = (
torch.randn(
num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
w2 = (
torch.randn(
num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
)
/ 10.0
)
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(bs, -1, hidden_dim)
.repeat(1, topk, 1)
.reshape(-1, hidden_dim)
)
hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
input_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
)
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
a2_global_scale = torch.ones(
(num_experts,), dtype=torch.float32, device=hidden_states.device
) # assume intermediate scale is 1.0
w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
w1,
w1_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
)
w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
w2,
w2_global_scale,
torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
)
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
out = flashinfer_cutedsl_moe_masked(
hidden_states_3d.to(hidden_states.device),
input_global_scale,
w1_fp4.permute(2, 0, 1),
w1_blockscale,
w1_alpha,
w2_fp4.permute(2, 0, 1),
a2_global_scale,
w2_blockscale,
w2_alpha,
masked_m.to(hidden_states.device),
)
# reference
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
w1_d = torch.empty(
(num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
)
w2_d = torch.empty(
(num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
)
for idx in range(0, num_experts):
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
w1[idx], w1_global_scale[idx]
)
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
w2[idx], w2_global_scale[idx]
)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_fp4_sliced,
w1_blockscale_sliced,
w1_global_scale[idx],
dtype=w1.dtype,
device=w1.device,
block_size=16,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_fp4_sliced,
w2_blockscale_sliced,
w2_global_scale[idx],
dtype=w2.dtype,
device=w2.device,
block_size=16,
)
ref_output = torch_moe_nvfp4(
a_in_dtype,
w1_d,
w2_d,
topk,
routing_weights.to(a_in_dtype.device),
topk_idx.to(a_in_dtype.device),
)
out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
rows, cols = positions[:, 0], positions[:, 1]
experts = topk_idx[rows, cols]
for i in range(num_experts):
mask = experts == i
if mask.any():
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
r, c = rows[idx], cols[idx]
out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
out.device
).unsqueeze(-1)
torch.testing.assert_close(
out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
)
@pytest.mark.parametrize(
"bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
torch.manual_seed(42)
B = bs
D = hidden_dim
N = inter_dim
num_experts = 8
hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(B, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
)
hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
# reference
out = torch.zeros(
(B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
)
for i in range(num_experts):
mask = topk_idx.view(-1) == i
if mask.sum():
lhs = hidden_states_expanded[mask]
rhs = weights[i]
a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
lhsq, lhsq_sf = fp4_quantize(
lhs,
a_gs,
)
rhsq, rhsq_sf = fp4_quantize(
rhs,
b_gs,
)
lhs_in_dtype = dequantize_nvfp4_to_dtype(
lhsq,
lhsq_sf,
a_gs,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
rhs_in_dtype = dequantize_nvfp4_to_dtype(
rhsq,
rhsq_sf,
b_gs,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
)
out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
a_amax = (
hidden_states_3d.abs()
.amax(dim=(1, 2))
.to(torch.float32)
.to(hidden_states.device)
)
b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
)
# re-pack out into [num_experts, max_m, n]
out_ref = torch.zeros(
(num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
)
expert_slot = [0] * num_experts
for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
out_ref[expert_id, expert_slot[expert_id], :] = out[i]
expert_slot[expert_id] += 1
# Note: compare only the masked positions, since cutedsl may write NaN
# into the unmasked positions.
for i in range(num_experts):
torch.testing.assert_close(
out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
atol=1e-1,
rtol=5e-2,
)
if __name__ == "__main__":
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
test_grouped_gemm_nt_masked(16, 128, 512, 4)
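The removed tests exercised the two-level NVFP4 scaling that the rest of this file also relies on: a global scale maps a tensor's absolute maximum into the combined FP8-block-scale times FP4-element range, and the GEMM alpha undoes both the activation and weight global scales. A small sketch of that arithmetic, assuming the usual values FLOAT8_E4M3_MAX = 448.0 and FLOAT4_E2M1_MAX = 6.0 (the file's actual constants are defined outside this diff):

import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value (assumed to match the file's constant)
FLOAT4_E2M1_MAX = 6.0    # largest float4_e2m1 magnitude (assumed to match the file's constant)

def nvfp4_global_scale(t: torch.Tensor) -> torch.Tensor:
    # Same form as w1_global_scale / w2_global_scale above: scale the tensor's amax
    # up to the largest value representable by an FP8 block scale times an FP4 element.
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / t.abs().amax().to(torch.float32)

# For a quantized GEMM C = A @ B with global scales a_gs and b_gs, the accumulator
# carries a factor of a_gs * b_gs, so the epilogue multiplies by
# alpha = 1.0 / (a_gs * b_gs) to restore the original magnitude, matching the
# w1_alpha and w2_alpha expressions in the removed test_flashinfer_cutedsl_moe_masked.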