change / sglang · Commit 3df05f4d (unverified)

Authored Sep 11, 2025 by Shu Wang; committed by GitHub on Sep 11, 2025.

[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#9199)

Parent: 7b141f81
Showing 11 changed files with 694 additions and 5 deletions (+694, -5)
docs/references/environment_variables.md (+5, -0)
python/sglang/srt/layers/moe/ep_moe/layer.py (+18, -0)
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (+156, -0)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+2, -1)
python/sglang/srt/layers/moe/utils.py (+4, -0)
python/sglang/srt/layers/quantization/modelopt_quant.py (+41, -1)
python/sglang/srt/models/deepseek_v2.py (+6, -2)
python/sglang/srt/server_args.py (+12, -0)
python/sglang/test/test_fp4_moe.py (+370, -1)
python/sglang/test/test_utils.py (+3, -0)
test/srt/test_cutedsl_flashinfer_8gpu.py (+77, -0)
docs/references/environment_variables.md

@@ -40,6 +40,11 @@ SGLang supports various environment variables that can be used to configure its
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |

## DeepEP Configuration
| Environment Variable | Description | Default Value |
| `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |

## Memory Management
| Environment Variable | Description | Default Value |
...
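For illustration (not part of the diff), the new switch is an ordinary boolean environment variable, so a launch script can opt in before the server starts. A minimal sketch, assuming the documented string values:

    # Minimal sketch: enable bf16 DeepEP dispatch before launching the server.
    # "true" / "1" turn it on; the documented default is "false".
    import os

    os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "true"
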
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -459,6 +459,8 @@ class DeepEPMoE(EPMoE):
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if get_moe_runner_backend().is_flashinfer_cutedsl():
                return self.forward_flashinfer_cutedsl(dispatch_output)
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
        else:

@@ -638,6 +640,22 @@ class DeepEPMoE(EPMoE):
        return gather_out

    def forward_flashinfer_cutedsl(
        self,
        dispatch_output: DeepEPLLOutput,
    ):
        hidden_states, _, _, masked_m, _ = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        output = self.quant_method.apply_without_routing_weights(
            layer=self,
            x=hidden_states,
            masked_m=masked_m,
            moe_runner_config=self.moe_runner_config,
        )
        return output

    def forward_deepgemm_masked(
        self,
        dispatch_output: DeepEPLLOutput,
...
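To make the masked layout concrete, here is a small illustrative sketch (toy dimensions, not part of the diff) of the per-expert padded tensor and `masked_m` that `forward_flashinfer_cutedsl` receives from the DeepEP low-latency dispatcher:

    import torch

    # Two local experts, each padded to m = 4 token slots, hidden size k = 8.
    num_experts, m, k = 2, 4, 8
    hidden_states = torch.randn(num_experts, m, k, dtype=torch.bfloat16)

    # Expert 0 received 3 tokens, expert 1 received 1; the remaining rows are
    # padding, and the masked grouped GEMM skips them (max(masked_m) <= m).
    masked_m = torch.tensor([3, 1], dtype=torch.int32)
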
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (new file, 0 → 100644)

from typing import Any, Dict, Optional

import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from sgl_kernel.gemm import (
    scaled_fp4_grouped_quant,
    silu_and_mul_scaled_fp4_grouped_quant,
)


def get_cute_dtype(input: torch.Tensor) -> str:
    if input.dtype == torch.bfloat16:
        return "bfloat16"
    elif input.dtype == torch.float16:
        return "float16"
    elif input.dtype == torch.float32:
        return "float32"
    else:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor,
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Args:
        hidden_states (torch.Tensor): [num_experts, m, k], bf16
        input_global_scale (torch.Tensor): (l,)
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): (l,) number of valid rows per expert

    Notes:
        - Assumes max(masked_m) <= m.
    """
    # === Assertions on dtypes ===
    assert (
        input_global_scale.dtype == torch.float32
    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert (
        w1_blockscale.dtype == torch.float8_e4m3fn
    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert (
        w1_alpha.dtype == torch.float32
    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert (
        a2_global_scale.dtype == torch.float32
    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert (
        w2_blockscale.dtype == torch.float8_e4m3fn
    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert (
        w2_alpha.dtype == torch.float32
    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    num_experts, m, k = hidden_states.shape

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert (
        w1.shape[-1] * 2 == k
    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

    assert input_global_scale.shape == (
        num_experts,
    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
    assert w1_alpha.shape == (
        num_experts,
    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert a2_global_scale.shape == (
        num_experts,
    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert w2_alpha.shape == (
        num_experts,
    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"

    aq, aq_sf = scaled_fp4_grouped_quant(
        hidden_states,
        input_global_scale,
        masked_m,
    )
    gateup_output = torch.empty(
        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
    )
    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"

    c_dtype = get_cute_dtype(hidden_states)

    # Gemm1
    grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization
    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
        gateup_output.permute(2, 0, 1),
        a2_global_scale,
        masked_m,
    )

    # Gemm2
    out = torch.empty_like(hidden_states)
    out = out.permute(1, 2, 0)  # requirement of kernel
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]
    return out.permute(2, 0, 1)
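As an illustrative shape walkthrough (toy dimensions; real inputs come from `scaled_fp4_grouped_quant`, as the unit test below shows), the contract of `flashinfer_cutedsl_moe_masked` can be summarized like this:

    # Toy dimensions: l experts, m padded tokens per expert, hidden k, intermediate n.
    l, m, k, n = 8, 16, 128, 256

    hidden_states_shape = (l, m, k)   # bf16 activations, padded per expert
    w1_shape = (l, 2 * n, k // 2)     # packed fp4: two e2m1 values per uint8 byte
    w2_shape = (l, k, n // 2)         # packed fp4
    scale_vectors_shape = (l,)        # input_global_scale, w1_alpha, a2_global_scale, w2_alpha

    # gemm1 -> [m, 2n, l] in kernel layout, silu_and_mul + requant halves 2n to n,
    # gemm2 -> [m, k, l], and the final permute returns a [l, m, k] tensor.
    assert w1_shape[-2] == 2 * n and w1_shape[-1] * 2 == k
    assert w2_shape[-2:] == (k, n // 2)
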
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -508,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        hidden_states, masked_m, event, hook = self._dispatch_core(
            hidden_states,
            topk_idx,
-           use_fp8=True,
+           # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
+           use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
        )
        return (
            hidden_states,
...
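For illustration, a small sketch of the gate above, assuming sglang's `get_bool_env_var` helper treats `"1"`/`"true"` as true (the 8-GPU test below sets the variable to `"1"`):

    import os

    from sglang.srt.utils import get_bool_env_var

    os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
    # When bf16 dispatch is requested, the low-latency dispatcher no longer
    # quantizes activations to fp8 before sending them over DeepEP.
    use_fp8 = not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH")
    assert use_fp8 is False
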
python/sglang/srt/layers/moe/utils.py

@@ -49,6 +49,7 @@ class MoeRunnerBackend(Enum):
    FLASHINFER = "flashinfer_trtllm"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"

    def is_auto(self):
        return self == MoeRunnerBackend.AUTO

@@ -65,6 +66,9 @@ class MoeRunnerBackend(Enum):
    def is_flashinfer_cutlass(self):
        return self == MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_cutedsl(self):
        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL

    def is_flashinfer_mxfp4(self):
        return self == MoeRunnerBackend.FLASHINFER_MXFP4
...
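The enum value doubles as the CLI string, so the backend selected via `--moe-runner-backend` round-trips cleanly. A quick sketch, assuming sglang is importable:

    from sglang.srt.layers.moe.utils import MoeRunnerBackend

    backend = MoeRunnerBackend("flashinfer_cutedsl")
    assert backend.is_flashinfer_cutedsl()
    assert not backend.is_flashinfer_cutlass() and not backend.is_flashinfer_mxfp4()
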
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -878,6 +878,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        """Access the global enable_flashinfer_cutlass_moe setting."""
        return get_moe_runner_backend().is_flashinfer_cutlass()

    @property
    def enable_flashinfer_cutedsl_moe(self) -> bool:
        from sglang.srt.layers.moe import get_moe_runner_backend

        """Access the global enable_flashinfer_cutedsl_moe setting."""
        return get_moe_runner_backend().is_flashinfer_cutedsl()

    def create_weights(
        self,
        layer: torch.nn.Module,

@@ -1398,5 +1405,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)

        # Scale by routed_scaling_factor is fused into select_experts.
        return StandardCombineInput(hidden_states=output)

    def apply_without_routing_weights(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        masked_m: torch.Tensor,
        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
        assert (
            moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
        assert (
            not moe_runner_config.apply_router_weight_on_input
        ), "apply_router_weight_on_input is not supported for Flashinfer"
        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
            flashinfer_cutedsl_moe_masked,
        )

        out = flashinfer_cutedsl_moe_masked(
            hidden_states=x,
            input_global_scale=layer.w13_input_scale_quant,
            w1=layer.w13_weight,
            w1_blockscale=layer.w13_blockscale_swizzled,
            w1_alpha=layer.g1_alphas,
            w2=layer.w2_weight,
            a2_global_scale=layer.w2_input_scale_quant,
            w2_blockscale=layer.w2_blockscale_swizzled,
            w2_alpha=layer.g2_alphas,
            masked_m=masked_m,
        )
        return out
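For a worked example of the scale bookkeeping behind `w13_input_scale_quant`, `g1_alphas`, and friends, following the conventions used in `test_fp4_moe.py` below (toy numbers, not values from the diff):

    # NVFP4 global scale: map a tensor's amax onto the fp8-scaled fp4 grid.
    FLOAT8_E4M3_MAX = 448.0
    FLOAT4_E2M1_MAX = 6.0

    w1_amax = 3.5                    # per-expert max(|w1|), toy value
    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax  # = 768.0
    input_global_scale = 1.0         # the test assumes activation scale 1.0

    # alpha undoes both quantization scales in the GEMM epilogue.
    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
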
python/sglang/srt/models/deepseek_v2.py

@@ -673,10 +673,14 @@ class DeepseekV2MoE(nn.Module):
        if shared_output is not None:
            x = shared_output
-           x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+           if self.experts.should_fuse_routed_scaling_factor_in_topk():
+               x.add_(final_hidden_states)
+           else:
+               x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
-           final_hidden_states *= self.routed_scaling_factor
+           if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+               final_hidden_states *= self.routed_scaling_factor
        return final_hidden_states
...
python/sglang/srt/server_args.py

@@ -399,6 +399,7 @@ class ServerArgs:
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    enable_flashinfer_cutlass_moe: bool = False
    enable_flashinfer_cutedsl_moe: bool = False
    enable_flashinfer_trtllm_moe: bool = False
    enable_triton_kernel_moe: bool = False
    enable_flashinfer_mxfp4_moe: bool = False

@@ -420,6 +421,11 @@ class ServerArgs:
            print_deprecated_warning(
                "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
            )
        if self.enable_flashinfer_cutedsl_moe:
            self.moe_runner_backend = "flashinfer_cutedsl"
            print_deprecated_warning(
                "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
            )
        if self.enable_flashinfer_cutlass_moe:
            self.moe_runner_backend = "flashinfer_cutlass"
            print_deprecated_warning(

@@ -1622,6 +1628,7 @@ class ServerArgs:
                "flashinfer_trtllm",
                "flashinfer_cutlass",
                "flashinfer_mxfp4",
                "flashinfer_cutedsl",
            ],
            default=ServerArgs.moe_runner_backend,
            help="Choose the runner backend for MoE.",

@@ -2204,6 +2211,11 @@ class ServerArgs:
            action="store_true",
            help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
        )
        parser.add_argument(
            "--enable-flashinfer-cutedsl-moe",
            action="store_true",
            help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
        )
        parser.add_argument(
            "--enable-flashinfer-trtllm-moe",
            action="store_true",
...
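The deprecation shim reduces to a one-line rewrite of the unified selector. A standalone sketch with a hypothetical `Args` stand-in for `ServerArgs` (the real code also prints the warning shown above):

    class Args:  # hypothetical stand-in for ServerArgs
        enable_flashinfer_cutedsl_moe = True
        moe_runner_backend = "auto"

    args = Args()
    if args.enable_flashinfer_cutedsl_moe:
        args.moe_runner_backend = "flashinfer_cutedsl"

    # New launches should pass --moe-runner-backend flashinfer_cutedsl directly.
    assert args.moe_runner_backend == "flashinfer_cutedsl"
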
python/sglang/test/test_fp4_moe.py

@@ -3,12 +3,15 @@ from typing import Callable
import pytest
import torch
from flashinfer import fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
from torch.nn import functional as F

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

if torch.cuda.get_device_capability() < (10, 0):

@@ -78,6 +81,37 @@ def break_fp4_bytes(a, dtype):
    return values.reshape(m, n * 2).to(dtype=dtype)


def compute_routing(router_logits: torch.Tensor, top_k: int):
    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.float()
    return routing_weights, selected_experts


def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    routing_weights, topk_idx = compute_routing(router_logits, topk)

    masked_m = []
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        masked_m.append(mask.sum())
    masked_m = torch.tensor(masked_m, dtype=torch.int32)

    hidden_states_3d = torch.empty(
        (num_experts, max(masked_m), hidden_states.shape[1]),
        dtype=hidden_states.dtype,
    )
    for i in range(num_experts):
        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]

    return hidden_states_3d, masked_m, topk_idx, routing_weights
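A tiny worked example of what `prepare_inputs` computes (illustrative values): with four expanded token slots routed as below, `masked_m` counts the occupancy of each expert, and the 3-D tensor groups rows by expert, padded to `max(masked_m)`:

    import torch

    topk_idx = torch.tensor([0, 2, 0, 1])  # expert chosen for each expanded row
    num_experts = 4
    masked_m = torch.tensor(
        [int((topk_idx == i).sum()) for i in range(num_experts)], dtype=torch.int32
    )
    # masked_m == [2, 1, 1, 0]; every expert slab is padded to max(masked_m) == 2.
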
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),

@@ -114,6 +148,99 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
    ).sum(dim=1)


def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            # Note: w1 and w3 are swapped!
            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
def flashinfer_cutedsl_grouped_gemm_nt_masked(
    hidden_states: torch.Tensor,  # 3d
    input_global_scale: torch.Tensor,  # (l,)
    weights: torch.Tensor,
    w_global_scale: torch.Tensor,  # (l,)
    masked_m: torch.Tensor,
):
    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked

    # hidden_states: [l, m, k]
    # weights: [l, n, k]
    aq, aq_sf = scaled_fp4_grouped_quant(
        hidden_states,
        input_global_scale,
        masked_m.to(hidden_states.device),
    )
    num_experts, n, k = weights.shape
    bq, bq_sf = scaled_fp4_grouped_quant(
        weights,
        w_global_scale,
        torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
    )
    out = torch.zeros(
        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
    )
    out = out.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
        1, 1, num_experts
    )

    def get_cute_dtype(input: torch.Tensor) -> str:
        if input.dtype == torch.bfloat16:
            return "bfloat16"
        elif input.dtype == torch.float16:
            return "float16"
        elif input.dtype == torch.float32:
            return "float32"
        else:
            raise ValueError(f"Unsupported cute dtype {input.dtype}")

    grouped_gemm_nt_masked(
        (aq, aq_sf),
        (bq, bq_sf),
        out,
        masked_m.to(aq.device),
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=alpha,
        alpha_dtype=get_cute_dtype(alpha),
    )
    return out
def check_moe(
    m: int,
    n: int,
...

@@ -324,6 +451,248 @@ def test_flashinfer_fp4_moe_no_graph(
    check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)


@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(bs: int, hidden_dim: int, inter_dim: int, topk: int):
    torch.manual_seed(42)
    device = "cuda"
    dtype = torch.bfloat16
    num_experts = 8
    hidden_states = (
        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
    )
    w1 = (
        torch.randn(
            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    w2 = (
        torch.randn(
            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(bs, -1, hidden_dim)
        .repeat(1, topk, 1)
        .reshape(-1, hidden_dim)
    )
    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
    input_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )
    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
    a2_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )
    # assume intermediate scale is 1.0

    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
        w1,
        w1_global_scale,
        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
    )
    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
        w2,
        w2_global_scale,
        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
    )
    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)

    out = flashinfer_cutedsl_moe_masked(
        hidden_states_3d.to(hidden_states.device),
        input_global_scale,
        w1_fp4.permute(2, 0, 1),
        w1_blockscale,
        w1_alpha,
        w2_fp4.permute(2, 0, 1),
        a2_global_scale,
        w2_blockscale,
        w2_alpha,
        masked_m.to(hidden_states.device),
    )

    # reference
    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        input_global_scale,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
        block_size=16,
    )
    w1_d = torch.empty(
        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
    )
    w2_d = torch.empty(
        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
    )
    for idx in range(0, num_experts):
        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
            w1[idx], w1_global_scale[idx]
        )
        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
            w2[idx], w2_global_scale[idx]
        )
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_fp4_sliced,
            w1_blockscale_sliced,
            w1_global_scale[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=16,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_fp4_sliced,
            w2_blockscale_sliced,
            w2_global_scale[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=16,
        )

    ref_output = torch_moe_nvfp4(
        a_in_dtype,
        w1_d,
        w2_d,
        topk,
        routing_weights.to(a_in_dtype.device),
        topk_idx.to(a_in_dtype.device),
    )

    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
    rows, cols = positions[:, 0], positions[:, 1]
    experts = topk_idx[rows, cols]
    for i in range(num_experts):
        mask = experts == i
        if mask.any():
            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
            r, c = rows[idx], cols[idx]
            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
                out.device
            ).unsqueeze(-1)

    torch.testing.assert_close(
        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
    )
@pytest.mark.parametrize(
    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(bs: int, hidden_dim: int, inter_dim: int, topk: int) -> None:
    torch.manual_seed(42)
    B = bs
    D = hidden_dim
    N = inter_dim
    num_experts = 8
    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(B, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    )
    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    # reference
    out = torch.zeros(
        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
    )
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        if mask.sum():
            lhs = hidden_states_expanded[mask]
            rhs = weights[i]
            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
            b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
            lhsq, lhsq_sf = fp4_quantize(
                lhs,
                a_gs,
            )
            rhsq, rhsq_sf = fp4_quantize(
                rhs,
                b_gs,
            )
            lhs_in_dtype = dequantize_nvfp4_to_dtype(
                lhsq,
                lhsq_sf,
                a_gs,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                block_size=16,
            )
            rhs_in_dtype = dequantize_nvfp4_to_dtype(
                rhsq,
                rhsq_sf,
                b_gs,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                block_size=16,
            )
            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()

    a_amax = (
        hidden_states_3d.abs()
        .amax(dim=(1, 2))
        .to(torch.float32)
        .to(hidden_states.device)
    )
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )

    # re-pack out into [num_experts, max_m, n]
    out_ref = torch.zeros(
        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
    )
    expert_slot = [0] * num_experts
    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
        expert_slot[expert_id] += 1

    # Note: only compare the masked positions, since CuteDSL may write NaN
    # into unmasked positions.
    for i in range(num_experts):
        torch.testing.assert_close(
            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
            atol=1e-1,
            rtol=5e-2,
        )


if __name__ == "__main__":
    test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
    test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
    test_grouped_gemm_nt_masked(16, 128, 512, 4)
python/sglang/test/test_utils.py

@@ -53,6 +53,9 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instru
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"

# NVFP4 models
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"

# FP8 models
DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
...
test/srt/test_cutedsl_flashinfer_8gpu.py (new file, 0 → 100644)

import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
    try_cached_model,
)


class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = try_cached_model(DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST)
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--disable-radix-cache",
            "--max-running-requests",
            "256",
            "--chunked-prefill-size",
            "2048",
            "--tp",
            "8",
            "--dp",
            "8",
            "--enable-dp-attention",
            "--enable-ep-moe",
            "--quantization",
            "modelopt_fp4",
            "--enable-flashinfer-cutedsl-moe",
            "--enable-deepep-moe",
            "--deepep-mode",
            "low_latency",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            env={
                **os.environ,
                "SGLANG_DEEPEP_BF16_DISPATCH": "1",
                "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
            },
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=512,
            parallel=512,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.92)


if __name__ == "__main__":
    unittest.main()