Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef6e0e71
Unverified
Commit
ef6e0e71
authored
Sep 30, 2025
by
CSWYF3634076
Committed by
GitHub
Sep 30, 2025
Browse files
[Bugfix][Model]fix ernie45 moe gate&bias dtype to float32 (#25936)
Signed-off-by:
wangyafeng
<
wangyafeng@baidu.com
>
parent
1ad3aca6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
7 deletions
+13
-7
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+3
-2
vllm/model_executor/models/ernie45_vl_moe.py
vllm/model_executor/models/ernie45_vl_moe.py
+10
-5
No files found.
vllm/model_executor/models/ernie45_moe.py
View file @
ef6e0e71
...
...
@@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module):
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
moe_num_experts
,
bias
=
False
,
params_dtype
=
torch
.
float32
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
gate
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
(
config
.
moe_num_experts
))
torch
.
empty
(
config
.
moe_num_experts
,
dtype
=
torch
.
float32
))
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
moe_num_experts
,
...
...
@@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module):
if
self
.
has_shared_experts
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
)
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
...
...
vllm/model_executor/models/ernie45_vl_moe.py
View file @
ef6e0e71
...
...
@@ -199,7 +199,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
assert
config
.
moe_num_experts
[
0
]
==
config
.
moe_num_experts
[
1
]
self
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
(
2
,
config
.
moe_num_experts
[
0
]))
torch
.
empty
(
2
,
config
.
moe_num_experts
[
0
]
,
dtype
=
torch
.
float32
))
assert
text_moe_layer_start_index
<=
text_moe_layer_end_index
...
...
@@ -209,6 +209,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
config
.
hidden_size
,
config
.
moe_num_experts
[
0
],
bias
=
False
,
params_dtype
=
torch
.
float32
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.text_experts_gate"
)
...
...
@@ -238,6 +239,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
config
.
hidden_size
,
config
.
moe_num_experts
[
1
],
bias
=
False
,
params_dtype
=
torch
.
float32
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_experts_gate"
)
...
...
@@ -288,7 +290,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
if
visual_token_mask
is
not
None
and
visual_token_mask
.
all
():
# only vision modal input
router_logits
,
_
=
self
.
vision_experts_gate
(
hidden_states
)
router_logits
,
_
=
self
.
vision_experts_gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
final_hidden_states
=
self
.
vision_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
elif
visual_token_mask
is
not
None
and
visual_token_mask
.
any
():
...
...
@@ -303,19 +306,21 @@ class Ernie4_5_VLMoeMoE(nn.Module):
vision_hidden_states
=
hidden_states
[
visual_token_mask
].
reshape
(
-
1
,
self
.
hidden_size
)
text_router_logits
,
_
=
self
.
text_experts_gate
(
text_hidden_states
)
text_router_logits
,
_
=
self
.
text_experts_gate
(
text_hidden_states
.
to
(
dtype
=
torch
.
float32
))
final_hidden_states
[
text_token_mask
]
=
self
.
text_experts
(
hidden_states
=
text_hidden_states
,
router_logits
=
text_router_logits
).
flatten
()
vision_router_logits
,
_
=
self
.
vision_experts_gate
(
vision_hidden_states
)
vision_hidden_states
.
to
(
dtype
=
torch
.
float32
)
)
final_hidden_states
[
visual_token_mask
]
=
self
.
vision_experts
(
hidden_states
=
vision_hidden_states
,
router_logits
=
vision_router_logits
).
flatten
()
else
:
# only text modal input
text_router_logits
,
_
=
self
.
text_experts_gate
(
hidden_states
)
text_router_logits
,
_
=
self
.
text_experts_gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
final_hidden_states
=
self
.
text_experts
(
hidden_states
=
hidden_states
,
router_logits
=
text_router_logits
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment