Commit e6f11356 (unverified) in sglang
support eplb for qwen3 (#6533)

Authored May 24, 2025 by Yi Zhang; committed via GitHub May 23, 2025.
Parent: 7b02c326

Showing 3 changed files with 46 additions and 25 deletions:

  python/sglang/srt/layers/moe/topk.py                +4   -2
  python/sglang/srt/managers/expert_distribution.py   +3   -1
  python/sglang/srt/models/qwen3_moe.py              +39  -22
python/sglang/srt/layers/moe/topk.py

@@ -65,6 +65,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -88,7 +89,7 @@ def fused_topk(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     return topk_weights, topk_ids

@@ -355,12 +356,13 @@ def select_experts(
             assert (
                 num_token_non_padded is None
             ), "num_token_non_padded is not yet supported in fused_topk"
-            assert expert_location_dispatch_info is None
+            # Qwen3MOE uses fused_topk
             topk_weights, topk_ids = fused_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
                 renormalize=renormalize,
+                expert_location_dispatch_info=expert_location_dispatch_info,
             )
         else:
             assert (
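The new expert_location_dispatch_info argument lets fused_topk remap the router's logical expert ids to physical expert slots via topk_ids_logical_to_physical. Below is a minimal standalone sketch of that remapping idea; the flat one-to-one lookup table is an assumption for illustration, not the layout of sglang's actual ExpertLocationDispatchInfo.

import torch

# Hypothetical table: logical expert id -> physical expert slot chosen for it.
# Under EPLB a logical expert can have several physical replicas, so the real
# dispatch info can be richer than this flat 1-to-1 map.
logical_to_physical = torch.tensor([0, 2, 3, 5], dtype=torch.int64)

topk_ids_logical = torch.tensor([[1, 3], [0, 2]], dtype=torch.int64)  # per-token top-k ids
topk_ids_physical = logical_to_physical[topk_ids_logical]
print(topk_ids_physical)  # tensor([[2, 5], [0, 3]])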
python/sglang/srt/managers/expert_distribution.py

@@ -690,7 +690,9 @@ def _convert_global_physical_count_to_logical_count(
     )
     logical_count.scatter_add_(
         dim=2,
-        index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
+        index=physical_to_logical_map.unsqueeze(0)
+        .expand(dim_extra, -1, -1)
+        .to(torch.int64),
         src=global_physical_count,
     )
     return logical_count
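The only change in this file is the .to(torch.int64) cast: torch.Tensor.scatter_add_ requires a LongTensor index, so a physical-to-logical map stored in a narrower integer dtype must be cast before being used as the scatter index. A small self-contained repro of the pattern, with toy shapes rather than sglang's real tensors:

import torch

dim_extra, num_layers, num_physical, num_logical = 2, 3, 6, 4

global_physical_count = torch.ones(dim_extra, num_layers, num_physical)
# Assume the map was stored as int32; scatter_add_ would reject it without the cast.
physical_to_logical_map = torch.randint(
    0, num_logical, (num_layers, num_physical), dtype=torch.int32
)

logical_count = torch.zeros(dim_extra, num_layers, num_logical)
logical_count.scatter_add_(
    dim=2,
    index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1).to(torch.int64),
    src=global_physical_count,
)
print(logical_count.sum())  # equals global_physical_count.sum()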
python/sglang/srt/models/qwen3_moe.py

@@ -55,7 +55,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts

@@ -67,6 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
@@ -86,28 +88,25 @@ logger = logging.getLogger(__name__)
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
         layer_id: int,
         config: Qwen3MoeConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        self.experts = MoEImpl(
-            num_experts=config.num_experts,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
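For context, the deleted MoEImpl ternary picked the MoE implementation from server flags; this commit routes the call site through get_moe_impl_class(), whose body is not part of this diff. The standalone sketch below mirrors the removed inline selection; the stub classes and the helper name pick_moe_impl_class are placeholders, not sglang's real objects or API.

# Sketch only: mirrors the inline ternary removed in this hunk. sglang's
# get_moe_impl_class() presumably centralizes this choice; these stubs stand in
# for the real DeepEPMoE, EPMoE, and FusedMoE classes.
class DeepEPMoE: ...
class EPMoE: ...
class FusedMoE: ...

def pick_moe_impl_class(server_args: dict):
    if server_args.get("enable_deepep_moe"):
        return DeepEPMoE
    if server_args.get("enable_ep_moe"):
        return EPMoE
    return FusedMoE

assert pick_moe_impl_class({"enable_ep_moe": True}) is EPMoE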
@@ -131,7 +130,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.num_experts
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
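The sizing change here (and the dispatcher change in the next hunk, which switches num_experts to self.num_experts) amounts to counting physical expert slots rather than logical experts. A minimal sketch of the arithmetic, with made-up numbers:

# Illustrative values only; the real ones come from the HF config and
# global_server_args_dict["ep_num_redundant_experts"].
num_logical_experts = 128        # config.num_experts
ep_num_redundant_experts = 16    # extra replicas EPLB may place on hot experts
num_physical_experts = num_logical_experts + ep_num_redundant_experts
assert num_physical_experts == 144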
@@ -139,7 +140,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.num_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.num_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
@@ -157,8 +158,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         else:
             return self.forward_deepep(hidden_states, forward_mode)
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
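get_moe_weights hands the routed experts' tensors to the EPLB machinery while filtering out correction_bias (presumably a routing-score bias rather than an expert weight). A self-contained sketch of the same filtering over a toy module; names and shapes are illustrative, not the real MoE implementation's.

import torch
from torch import nn

class ToyExperts(nn.Module):
    # Toy stand-in for the selected MoE implementation.
    def __init__(self):
        super().__init__()
        self.w13_weight = nn.Parameter(torch.randn(4, 8, 16))   # fused gate/up proj
        self.w2_weight = nn.Parameter(torch.randn(4, 16, 8))    # down proj
        self.correction_bias = nn.Parameter(torch.zeros(4))     # excluded below

experts = ToyExperts()
moe_weights = [
    x.data for name, x in experts.named_parameters() if name not in ["correction_bias"]
]
print([t.shape for t in moe_weights])  # two expert weight tensors, no correction_bias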
@@ -189,6 +196,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             topk_idx = torch.full(
@@ -408,6 +418,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         if self.info.is_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
+                layer_id=self.layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -685,15 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl
=
(
DeepEPMoE
if
global_server_args_dict
[
"enable_deepep_moe"
]
else
(
EPMoE
if
global_server_args_dict
[
"enable_ep_moe"
]
else
FusedMoE
)
)
expert_params_mapping
=
MoEImpl
.
make_expert_params_mapping
(
expert_params_mapping
=
get_moe_impl_class
().
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
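As the comment in the hunk above notes, the mapping is a list of (param_name, weight_name, expert_id, shard_id) tuples that load_weights uses to route per-expert checkpoint tensors into the fused expert parameters. Below is a hedged sketch of what such a mapping could look like; the exact string patterns and the helper name are assumptions for illustration, not sglang's actual output of make_expert_params_mapping.

# Illustrative only: the real mapping comes from the selected MoE class and its
# naming conventions may differ.
def make_expert_params_mapping_sketch(num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in [("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")]:
            param_name = "experts.w13_weight" if shard_id in ("w1", "w3") else "experts.w2_weight"
            weight_name = f"experts.{expert_id}.{ckpt_name}."
            mapping.append((param_name, weight_name, expert_id, shard_id))
    return mapping

print(make_expert_params_mapping_sketch(num_experts=2)[:3])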
@@ -770,5 +773,19 @@ class Qwen3MoeForCausalLM(nn.Module):
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
+        self.routed_experts_weights_of_layer = {
+            layer_id: layer.mlp.get_moe_weights()
+            for layer_id, layer in enumerate(self.model.layers)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+        }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+
 
 EntryClass = Qwen3MoeForCausalLM
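The new classmethod reports the expert topology that EPLB code needs from the HF config alone: layer count, logical expert count, and no expert groups for Qwen3 MoE. A standalone sketch of the record it returns; the dataclass and function names below are stand-ins for sglang's ModelConfigForExpertLocation and the added classmethod, not their actual definitions.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ExpertLocationConfigSketch:
    # Stand-in for sglang.srt.managers.expert_location.ModelConfigForExpertLocation.
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int]

def qwen3_moe_expert_location_config(num_hidden_layers: int, num_experts: int):
    # Mirrors the classmethod added in this commit; values come from the HF config.
    return ExpertLocationConfigSketch(
        num_layers=num_hidden_layers,
        num_logical_experts=num_experts,
        num_groups=None,  # the commit reports no grouped-expert routing for Qwen3 MoE
    )

print(qwen3_moe_expert_location_config(num_hidden_layers=48, num_experts=128))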