Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f0a30a06
Unverified
Commit
f0a30a06
authored
Oct 11, 2025
by
Jee Jee Li
Committed by
GitHub
Oct 11, 2025
Browse files
[Bugfix] Fix qwen-moe packed_modules_mapping (#26634)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
9d6cff3e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
11 deletions
+23
-11
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+1
-1
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+13
-5
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+9
-5
No files found.
vllm/model_executor/models/interfaces.py
View file @
f0a30a06
...
...
@@ -325,7 +325,7 @@ class SupportsLoRA(Protocol):
# are empty by default.
embedding_modules
:
ClassVar
[
dict
[
str
,
str
]]
=
{}
embedding_padding_modules
:
ClassVar
[
list
[
str
]]
=
[]
packed_modules_mapping
:
ClassVar
[
dict
[
str
,
list
[
str
]]
]
=
{}
packed_modules_mapping
:
dict
[
str
,
list
[
str
]]
=
{}
# We can't use runtime_checkable with ClassVar for issubclass checks
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
f0a30a06
...
...
@@ -534,11 +534,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
]
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -547,6 +543,18 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Only perform the following mapping when Qwen2MoeMLP exists
if
(
getattr
(
config
,
"mlp_only_layers"
,
[])
or
config
.
shared_expert_intermediate_size
>
0
):
self
.
packed_modules_mapping
[
"gate_up_proj"
]
=
(
[
"gate_proj"
,
"up_proj"
,
],
)
self
.
model
=
Qwen2MoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
f0a30a06
...
...
@@ -634,11 +634,7 @@ class Qwen3MoeForCausalLM(
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
]
}
fall_back_to_pt_during_load
=
False
...
...
@@ -649,6 +645,14 @@ class Qwen3MoeForCausalLM(
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Only perform the following mapping when Qwen3MoeMLP exists
if
getattr
(
config
,
"mlp_only_layers"
,
[]):
self
.
packed_modules_mapping
[
"gate_up_proj"
]
=
(
[
"gate_proj"
,
"up_proj"
,
],
)
self
.
model
=
Qwen3MoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment