1344ebc8
Unverified
Commit
1344ebc8
authored
Sep 19, 2025
by
Yi Zhang
Committed by
GitHub
Sep 18, 2025
Browse files
support qwen3-next-fp8 deepep (#10622)
parent
e07b21ce
Changes
2
Showing 2 changed files with 93 additions and 9 deletions (+93 -9):

  python/sglang/srt/models/qwen2_moe.py   +64  -1
  python/sglang/srt/models/qwen3_next.py  +29  -8
python/sglang/srt/models/qwen2_moe.py  (view file @ 1344ebc8)
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
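In the hunk above the routed-expert pool is sized as config.num_experts plus ep_num_redundant_experts, leaving room for EPLB to place redundant copies of hot experts. A back-of-the-envelope illustration of the sizing; the concrete values and the even split across EP ranks below are assumptions for illustration, not sglang's placement logic:

```python
num_experts = 512        # illustrative stand-in for config.num_experts
ep_num_redundant = 32    # illustrative global_server_args_dict["ep_num_redundant_experts"]
ep_size = 8              # illustrative expert-parallel world size

# Total expert slots the MoE layer is built with in the hunk above.
total_slots = num_experts + ep_num_redundant

# Assuming (for illustration only) an even split of slots across EP ranks:
slots_per_rank = total_slots // ep_size
print(total_slots, slots_per_rank)  # 544 68
```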
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
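The `**(dict(tp_rank=0, tp_size=1) if get_moe_a2a_backend().is_deepep() else {})` argument above injects the tp_rank/tp_size overrides only when the DeepEP all-to-all backend is active, keeping the shared expert replicated (tp_size=1) in that case. A minimal sketch of the conditional **-splat idiom itself; the build_shared_expert factory and use_deepep flag below are hypothetical, not sglang APIs:

```python
from typing import Optional


def build_shared_expert(
    hidden_size: int,
    tp_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
) -> dict:
    # Hypothetical factory: None means "inherit the global tensor-parallel layout".
    return {"hidden_size": hidden_size, "tp_rank": tp_rank, "tp_size": tp_size}


use_deepep = True  # stand-in for get_moe_a2a_backend().is_deepep()

shared_expert = build_shared_expert(
    4096,
    # Only when DeepEP is active, pin the shared expert to a single-rank,
    # replicated layout; otherwise pass no overrides at all.
    **(dict(tp_rank=0, tp_size=1) if use_deepep else {}),
)
print(shared_expert)  # {'hidden_size': 4096, 'tp_rank': 0, 'tp_size': 1}
```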
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             )
         return shared_output

+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(hidden_states.device)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+        return final_hidden_states
+
     def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
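For orientation, the control flow `_forward_deepep` adds above is gate, then top-k expert selection, then expert compute, with the shared-expert output added in place at the end. The toy sketch below mirrors only that control flow with plain PyTorch on dense tensors; the real path dispatches tokens across ranks via DeepEP, handles padded and empty batches, and uses sglang's TopK/expert modules, all of which are omitted here:

```python
import torch


def toy_deepep_style_forward(
    hidden_states: torch.Tensor,   # (num_tokens, hidden)
    gate_weight: torch.Tensor,     # (hidden, num_experts)
    expert_weights: torch.Tensor,  # (num_experts, hidden, hidden)
    shared_weight: torch.Tensor,   # (hidden, hidden), stand-in shared expert
    top_k: int = 2,
) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts)
    router_logits = hidden_states @ gate_weight
    # The shared expert runs on every token, like _forward_shared_experts.
    shared_output = hidden_states @ shared_weight
    # Top-k routing weights, renormalized over the selected experts.
    topk_weights, topk_idx = router_logits.softmax(dim=-1).topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Dense per-token expert compute (the real code dispatches to expert ranks).
    out = torch.zeros_like(hidden_states)
    for k in range(top_k):
        expert_per_token = expert_weights[topk_idx[:, k]]  # (tokens, hidden, hidden)
        expert_out = torch.bmm(hidden_states.unsqueeze(1), expert_per_token).squeeze(1)
        out += topk_weights[:, k : k + 1] * expert_out
    # Add the shared-expert contribution in place, mirroring add_() above.
    out.add_(shared_output)
    return out


# Tiny smoke test with random weights.
tokens, hidden, experts = 8, 16, 4
y = toy_deepep_style_forward(
    torch.randn(tokens, hidden),
    torch.randn(hidden, experts),
    torch.randn(experts, hidden, hidden),
    torch.randn(hidden, hidden),
)
print(y.shape)  # torch.Size([8, 16])
```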
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
         DUAL_STREAM_TOKEN_THRESHOLD = 1024
         if (
             self.alt_stream is not None
python/sglang/srt/models/qwen3_next.py  (view file @ 1344ebc8)
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
 from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
@@ -849,6 +857,7 @@ class Qwen3NextModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
+            with get_global_expert_distribution_recorder().with_current_layer(i):
                 hidden_states, residual = layer(
                     layer_id=i,
                     positions=positions,
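The added `with_current_layer(i)` context scopes each layer's forward pass so the global expert-distribution recorder can attribute routing statistics to the right layer. The recorder's internals are not part of this diff; a minimal, hypothetical stand-in for the layer-scoping pattern might look like:

```python
import contextlib


class ToyExpertDistributionRecorder:
    """Hypothetical stand-in for the recorder used above; it only tracks which
    layer is currently executing so per-layer events can be attributed to it."""

    def __init__(self):
        self.current_layer = None
        self.counts = {}  # layer_id -> number of recorded events

    @contextlib.contextmanager
    def with_current_layer(self, layer_id: int):
        prev = self.current_layer
        self.current_layer = layer_id
        try:
            yield
        finally:
            # Restore the previous scope even if the layer raises.
            self.current_layer = prev

    def record(self):
        # Attribute an event to whichever layer is currently in scope.
        self.counts[self.current_layer] = self.counts.get(self.current_layer, 0) + 1


recorder = ToyExpertDistributionRecorder()
for i in range(3):
    with recorder.with_current_layer(i):
        recorder.record()
print(recorder.counts)  # {0: 1, 1: 1, 2: 1}
```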
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
             self.lm_head = self.lm_head.float()
         self.logits_processor = LogitsProcessor(config)

+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
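`LazyValue` is imported from `sglang.srt.utils` but not defined in this diff; the property above relies on it deferring the dict construction until `.value` is first read, after the model's layers exist. A minimal, hypothetical stand-in with the same compute-once semantics:

```python
class ToyLazyValue:
    """Hypothetical stand-in for sglang.srt.utils.LazyValue: defer a computation
    until .value is first accessed, then cache and reuse the result."""

    _UNSET = object()

    def __init__(self, creator):
        self._creator = creator
        self._value = self._UNSET

    @property
    def value(self):
        if self._value is self._UNSET:
            # First access: run the creator once and cache its result.
            self._value = self._creator()
        return self._value


calls = []
lazy = ToyLazyValue(lambda: calls.append("computed") or {"layer_0": "weights"})
print(lazy.value)  # computed on first access
print(lazy.value)  # served from the cache
print(calls)       # ['computed']  (creator ran exactly once)
```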