Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c15309a7
Unverified
Commit
c15309a7
authored
Sep 17, 2025
by
whx
Committed by
GitHub
Sep 17, 2025
Browse files
[Model] Apply SharedFusedMoE to glm4_moe. (#24849)
Signed-off-by:
whx-sjtu
<
2952154980@qq.com
>
parent
4a9375fe
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
30 deletions
+55
-30
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+55
-30
No files found.
vllm/model_executor/models/glm4_moe.py
View file @
c15309a7
...
@@ -46,6 +46,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -46,6 +46,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -146,7 +147,19 @@ class Glm4MoE(nn.Module):
...
@@ -146,7 +147,19 @@ class Glm4MoE(nn.Module):
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
self
.
shared_experts
=
Glm4MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
...
@@ -163,35 +176,47 @@ class Glm4MoE(nn.Module):
...
@@ -163,35 +176,47 @@ class Glm4MoE(nn.Module):
routed_scaling_factor
=
1.0
,
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
num_redundant_experts
=
self
.
n_redundant_experts
,
)
if
config
.
n_shared_experts
is
not
Non
e
:
els
e
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
self
.
experts
=
FusedMoE
(
config
.
n_
shar
ed_experts
)
num_experts
=
config
.
n_
rout
ed_experts
,
self
.
shared_experts
=
Glm4MoeMLP
(
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
use_grouped_topk
=
True
,
),
num_expert_group
=
config
.
n_group
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
topk_group
=
config
.
topk_group
,
)
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
"sigmoid"
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
# router_logits: (num_tokens, n_experts)
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
shared_output
=
None
router_logits
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
router_logits
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
fused_moe_out
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
shared_experts
is
not
None
:
shared_output
,
final_hidden_states
=
fused_moe_out
assert
shared_output
is
not
None
final_hidden_states
=
\
final_hidden_states
*
self
.
routed_scaling_factor
\
+
shared_output
else
:
final_hidden_states
=
fused_moe_out
*
self
.
routed_scaling_factor
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
(
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment