Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1b769dcc
Unverified
Commit
1b769dcc
authored
Jul 28, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 28, 2025
Browse files
[Bugfix] Fix Ernie4_5_MoeForCausalLM shared experts (#21717)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
2cc57119
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+6
-5
No files found.
vllm/model_executor/models/ernie45_moe.py
View file @
1b769dcc
...
...
@@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module):
layer_idx
=
extract_layer_index
(
prefix
)
self
.
layer_idx
=
layer_idx
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
moe_num
_shared_experts
=
getattr
(
config
,
"moe_num_shared_experts"
,
None
)
self
.
has
_shared_experts
=
(
getattr
(
config
,
"moe_num_shared_experts"
,
0
)
>
0
)
if
self
.
tp_size
>
config
.
moe_num_experts
:
raise
ValueError
(
...
...
@@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module):
prefix
=
f
"
{
prefix
}
.experts"
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
)
if
self
.
moe_num
_shared_experts
is
not
None
:
if
self
.
has
_shared_experts
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
moe_num_shared_experts
)
self
.
shared_experts
=
Ernie4_5_MoeMLP
(
...
...
@@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module):
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
moe_num_shared_experts
is
not
None
:
shared_output
=
None
if
self
.
has_shared_experts
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
...
...
@@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
self
.
moe_num
_shared_experts
is
not
None
and
\
if
self
.
has
_shared_experts
and
\
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment