change / sglang · Commits · Commit 0b9dfba7 (Unverified)

Support dispatch low latency (#10263)

Authored Oct 02, 2025 by fzyzcjy; committed by GitHub on Oct 02, 2025.
Co-authored-by: Kaixi Hou <4001424+kaixih@users.noreply.github.com>
Parent: 6a290034

Showing 5 changed files with 80 additions and 29 deletions (+80 -29):
- python/sglang/srt/layers/moe/ep_moe/layer.py (+11 -0)
- python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (+38 -25)
- python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+19 -3)
- python/sglang/srt/layers/quantization/modelopt_quant.py (+11 -1)
- python/sglang/srt/models/deepseek_v2.py (+1 -0)
python/sglang/srt/layers/moe/ep_moe/layer.py

```diff
@@ -31,6 +31,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
+from sglang.srt.layers.quantization.modelopt_quant import (
+    CUTEDSL_MOE_NVFP4_DISPATCH,
+    ModelOptNvFp4FusedMoEMethod,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
...
@@ -453,6 +457,13 @@ class DeepEPMoE(EPMoE):
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            input_global_scale=(
+                self.w13_input_scale_quant
+                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+                and self.quant_method.enable_flashinfer_cutedsl_moe
+                and CUTEDSL_MOE_NVFP4_DISPATCH
+                else None
+            ),
         )

     def moe_impl(self, dispatch_output: DispatchOutput):
...
```
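The conditional above is the only place the new argument gets populated: the dispatcher receives a global input scale only when the quant method is ModelOptNvFp4FusedMoEMethod, its CuteDSL MoE path is enabled, and the CUTEDSL_MOE_NVFP4_DISPATCH flag is on; in every other configuration it receives None and dispatch behaves as before. A self-contained restatement of that gate (the helper itself is hypothetical, only the three conditions mirror the diff):

```python
from typing import Optional

import torch


# Illustrative restatement of the gating in DeepEPMoE.dispatch above;
# this helper is not part of the commit.
def pick_input_global_scale(
    quant_method: object,
    w13_input_scale_quant: torch.Tensor,
    nvfp4_dispatch_flag: bool,  # stands in for CUTEDSL_MOE_NVFP4_DISPATCH
) -> Optional[torch.Tensor]:
    if (
        type(quant_method).__name__ == "ModelOptNvFp4FusedMoEMethod"
        and getattr(quant_method, "enable_flashinfer_cutedsl_moe", False)
        and nvfp4_dispatch_flag
    ):
        return w13_input_scale_quant
    return None  # all other configurations keep the pre-commit behavior
```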
python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py

```diff
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
...
@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
...
@@ -36,7 +36,9 @@ def flashinfer_cutedsl_moe_masked(
     kernels.

     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
```
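To make the two docstring cases concrete: NVFP4 packs two 4-bit values into each uint8 byte (hence k // 2) and attaches one e4m3 scale per 16-value block (hence k // 16). A shape sketch with made-up sizes, following the layout as written in the docstring above (the function body reads the packed shape in a different axis order, so treat the exact order as an implementation detail):

```python
import torch

# Made-up sizes, for illustration only.
num_experts, m, k = 8, 64, 7168

# Form 1: unquantized bf16 activations; the kernel quantizes them itself.
dense = torch.randn(num_experts, m, k, dtype=torch.bfloat16)

# Form 2: activations already quantized (e.g. by nvfp4 dispatch): packed fp4
# payload (two values per byte) plus per-16-value e4m3 block scales.
packed = torch.empty(num_experts, m, k // 2, dtype=torch.uint8)
scales = torch.empty(num_experts, m, k // 16, dtype=torch.float8_e4m3fn)
pre_quantized = (packed, scales)  # passed with input_global_scale=None
```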
```diff
@@ -48,13 +50,10 @@ def flashinfer_cutedsl_moe_masked(
         masked_m (torch.Tensor): Masked dimension indices

     Notes:
-        - Assumes max(masked_m) <= m.
+        - Assumes max(masked_m) == m.
     """
     # === Assertions on dtypes ===
-    assert (
-        input_global_scale.dtype == torch.float32
-    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
     assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
     assert (
         w1_blockscale.dtype == torch.float8_e4m3fn
...
@@ -75,7 +74,31 @@ def flashinfer_cutedsl_moe_masked(
     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert (
+            input_global_scale is None
+        ), "input_global_scale is needed when input needs quant"
+        a_q = hidden_states[0].view(torch.uint8)
+        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        m, k_by_2, num_experts = a_q.shape
+        k = k_by_2 * 2
+    else:
+        num_experts, m, k = hidden_states.shape
+        assert (
+            input_global_scale.dtype == torch.float32
+        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        assert input_global_scale.shape == (
+            num_experts,
+        ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+        a_q, a_q_sf = scaled_fp4_grouped_quant(
+            hidden_states,
+            input_global_scale,
+            masked_m,
+        )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert (
```
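The tuple branch unpacks the dispatched payload with Tensor.view(dtype), which reinterprets the same storage under a same-itemsize dtype instead of copying. A minimal demonstration with placeholder data:

```python
import torch

# uint8 and float8_e4m3fn are both 1-byte dtypes, so view() is a free
# reinterpretation of the same buffer: no copy, same shape, same storage.
raw = torch.arange(16, dtype=torch.uint8).reshape(4, 4)
as_fp8 = raw.view(torch.float8_e4m3fn)

assert as_fp8.shape == raw.shape
assert as_fp8.data_ptr() == raw.data_ptr()  # same underlying memory
```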
```diff
@@ -85,10 +108,6 @@ def flashinfer_cutedsl_moe_masked(
         k,
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
-    assert input_global_scale.shape == (
-        num_experts,
-    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
     assert w1_alpha.shape == (
         num_experts,
     ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
...
@@ -99,27 +118,21 @@ def flashinfer_cutedsl_moe_masked(
         num_experts,
     ), f"w2_alpha must be (l,), got {w2_alpha.shape}"

-    aq, aq_sf = scaled_fp4_grouped_quant(
-        hidden_states,
-        input_global_scale,
-        masked_m,
-    )
+    # TODO(kaixih@nvidia): dtype should be based on inputs.
     gateup_output = torch.empty(
-        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+        (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
     )
     gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
-    assert aq_sf.dtype == torch.float8_e4m3fn
-    assert aq.dtype == torch.uint8
+    assert a_q_sf.dtype == torch.float8_e4m3fn
+    assert a_q.dtype == torch.uint8
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"

-    c_dtype = get_cute_dtype(hidden_states)
+    c_dtype = "bfloat16"

     # Gemm1
     grouped_gemm_nt_masked(
-        (aq, aq_sf),
+        (a_q, a_q_sf),
         (w1.permute(1, 2, 0), w1_blockscale),
         gateup_output,
         masked_m,
...
@@ -139,7 +152,7 @@ def flashinfer_cutedsl_moe_masked(
     )

     # Gemm2
-    out = torch.empty_like(hidden_states)
+    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
     out = out.permute(1, 2, 0)  # requirement of kernel
     grouped_gemm_nt_masked(
         (diq, diq_sf),
...
```
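Both GEMM outputs are allocated (l, m, n)-contiguous and then viewed through permute(1, 2, 0), matching the "requirement of kernel" comments: the masked grouped GEMM expects the expert dimension last, and permute only rewrites strides. A small sketch with assumed sizes:

```python
import torch

# Assumed sizes: experts, tokens per expert, 2 * intermediate dimension.
l, m, n2 = 4, 32, 512

out = torch.empty((l, m, n2), dtype=torch.bfloat16)
out_view = out.permute(1, 2, 0)  # shape (m, n2, l); strides change, storage does not

assert out_view.shape == (m, n2, l)
assert out_view.data_ptr() == out.data_ptr()  # zero-copy relayout
```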
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

```diff
@@ -296,6 +296,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
...
@@ -329,6 +330,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
...
@@ -505,6 +507,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
...
@@ -516,9 +519,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
+            input_global_scale,
             topk_idx,
-            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
-            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
...
@@ -558,9 +560,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
-        use_fp8: bool = False,
     ):
+        use_nvfp4 = use_fp8 = False
+        if input_global_scale is not None:
+            use_nvfp4 = True
+        elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
+            use_fp8 = True
+
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
```
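The new selection makes the three transport modes mutually exclusive, with NVFP4 taking priority over the fp8/bf16 choice. A compact restatement (env-var handling reduced to a plain bool for illustration):

```python
from typing import Optional

import torch


# Illustrative restatement of the mode choice in _dispatch_core: a supplied
# global scale forces nvfp4; otherwise fp8, unless bf16 dispatch is forced.
def pick_mode(input_global_scale: Optional[torch.Tensor], bf16_forced: bool) -> str:
    if input_global_scale is not None:
        return "nvfp4"
    if not bf16_forced:
        return "fp8"
    return "bf16"


assert pick_mode(torch.ones(8), bf16_forced=True) == "nvfp4"  # the scale wins
assert pick_mode(None, bf16_forced=False) == "fp8"
assert pick_mode(None, bf16_forced=True) == "bf16"
```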
```diff
@@ -569,6 +577,12 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
+                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                **(
+                    dict(x_global_scale=input_global_scale)
+                    if input_global_scale is not None
+                    else dict()
+                ),
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
...
```
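The `**(dict(...) if ... else dict())` splats forward the new keywords only when the NVFP4 path is active, so DeepEP builds whose `low_latency_dispatch` does not yet accept use_nvfp4/x_global_scale keep working as long as the flag stays off (the feature is off by default until the DeepEP PR lands, per the TODO in modelopt_quant.py). A generic sketch of the pattern:

```python
# Generic sketch of the conditional-kwargs pattern: the keyword reaches the
# callee only when the feature is active, so an older callee that lacks the
# parameter keeps working whenever the feature is disabled.
def legacy_dispatch(x, use_fp8=False):  # illustrative stand-in, no nvfp4 kwarg
    return ("fp8" if use_fp8 else "bf16", x)


use_nvfp4 = False  # flag off: the extra kwarg is not passed at all
mode, _ = legacy_dispatch(
    [1, 2, 3],
    use_fp8=True,
    **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
)
assert mode == "fp8"
```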
```diff
@@ -682,6 +696,7 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
...
@@ -689,6 +704,7 @@ class DeepEPDispatcher(BaseDispatcher):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
+            input_global_scale=input_global_scale,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
...
```
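The `_update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)` call above belongs to a small state machine that enforces the dispatcher's phase ordering. An illustrative reduction of that guard (the enum and class here are stand-ins, not the sglang originals):

```python
from enum import Enum, auto


class _Stage(Enum):  # stand-in for sglang's stage enum
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()


class TwoPhaseDispatcher:  # illustrative reduction, not the real class
    def __init__(self) -> None:
        self._stage = _Stage.INITIAL

    def _update_stage(self, old: "_Stage", new: "_Stage") -> None:
        # Each phase asserts the expected previous phase, so out-of-order
        # calls fail loudly instead of silently corrupting dispatch state.
        assert self._stage == old, f"expected {old}, got {self._stage}"
        self._stage = new

    def dispatch_a(self) -> None:
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)


d = TwoPhaseDispatcher()
d.dispatch_a()    # ok
# d.dispatch_a()  # would raise: the stage has already advanced
```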
python/sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -80,6 +80,10 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
     "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
 )
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)

 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
```
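The flag is opt-in for now (note the explicit "false" default and the TODO above). A usage sketch: the constant is evaluated once at module import, so the variable must be set before sglang modules load, and the exact truthy spellings are whatever get_bool_env_var accepts (assumed here to include "true"):

```python
import os

# Opt in before any sglang import: CUTEDSL_MOE_NVFP4_DISPATCH is read once,
# when modelopt_quant.py is first imported.
os.environ["SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH"] = "true"

import sglang  # noqa: E402  (imported only after the flag is set)
```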
```diff
@@ -1234,6 +1238,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w13_input_scale = _slice_scale(w13_input_scale)
             w2_input_scale = _slice_scale(w2_input_scale)
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
...
```
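The collapse above exists because pre-dispatch quantization happens once per token, before routing, so a single scale has to serve every expert; the assert guards that the per-expert scales really are uniform. A self-contained illustration with made-up values:

```python
import torch

# Made-up per-expert input scales; with nvfp4 dispatch they must all agree,
# because each token is quantized once before it is routed to any expert.
w13_input_scale = torch.full((8,), 0.5, dtype=torch.float32)

assert torch.all(w13_input_scale == w13_input_scale[0])  # uniformity guard
w13_input_scale = w13_input_scale[0]  # collapse (8,) -> scalar tensor

assert w13_input_scale.ndim == 0
```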
```diff
@@ -1476,7 +1484,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         out = flashinfer_cutedsl_moe_masked(
             hidden_states=x,
-            input_global_scale=layer.w13_input_scale_quant,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
             w1=layer.w13_weight,
             w1_blockscale=layer.w13_blockscale_swizzled,
             w1_alpha=layer.g1_alphas,
...
```
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -896,6 +896,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
...
```