Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8edaf385
Unverified
Commit
8edaf385
authored
Jan 24, 2026
by
Isotr0py
Committed by
GitHub
Jan 23, 2026
Browse files
[Models] Add `SharedFusedMoE` support to Qwen3MoE (#32082)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
5c86a898
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
16 deletions
+56
-16
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+56
-16
No files found.
vllm/model_executor/models/qwen3_moe.py
View file @
8edaf385
...
...
@@ -29,6 +29,7 @@ from itertools import islice
from
typing
import
Any
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
vllm.attention.layer
import
Attention
...
...
@@ -42,7 +43,7 @@ from vllm.distributed import (
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -86,6 +87,7 @@ class Qwen3MoeMLP(nn.Module):
hidden_act
:
str
,
quant_config
:
QuantizationConfig
|
None
=
None
,
reduce_results
:
bool
=
True
,
expert_gate
:
torch
.
nn
.
Linear
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -109,12 +111,17 @@ class Qwen3MoeMLP(nn.Module):
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
self
.
expert_gate
=
expert_gate
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
out
=
self
.
act_fn
(
gate_up
)
out
,
_
=
self
.
down_proj
(
out
)
if
self
.
expert_gate
is
not
None
:
out
=
F
.
sigmoid
(
self
.
expert_gate
(
x
)[
0
])
*
out
return
out
class
Qwen3MoeSparseMoeBlock
(
nn
.
Module
):
...
...
@@ -159,12 +166,46 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
shared_expert_intermediate_size
=
getattr
(
config
,
"shared_expert_intermediate_size"
,
0
)
if
shared_expert_intermediate_size
>
0
:
self
.
shared_expert_gate
=
ReplicatedLinear
(
config
.
hidden_size
,
1
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.shared_expert_gate"
,
)
self
.
shared_expert
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
shared_expert_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
expert_gate
=
self
.
shared_expert_gate
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
)
else
:
self
.
shared_expert_gate
=
None
self
.
shared_expert
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_expert
,
gate
=
self
.
gate
,
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
Tru
e
,
reduce_results
=
Fals
e
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
...
...
@@ -173,14 +214,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
hidden_states
.
dim
()
<=
2
,
(
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
...
...
@@ -194,15 +227,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
shared_out
,
fused_out
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
final_hidden_states
=
(
shared_out
+
fused_out
if
shared_out
is
not
None
else
fused_out
)
if
self
.
is_sequence_parallel
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
,
0
)
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
elif
self
.
tp_size
>
1
:
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
# noqa E501
final_hidden_states
)
# return to 1d if input is 1d
return
final_hidden_states
.
squeeze
(
0
)
if
is_input_1d
else
final_hidden_states
...
...
@@ -467,7 +507,7 @@ class Qwen3MoeModel(nn.Module):
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment