Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b159c0a6
Unverified
Commit
b159c0a6
authored
Aug 13, 2025
by
Gh0u1L5
Committed by
GitHub
Aug 13, 2025
Browse files
Fix GGUF loader for Qwen3 MoE. (#22785)
Signed-off-by:
Gh0u1L5
<
Gh0u1L5@outlook.com
>
parent
6772bb0f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
0 deletions
+12
-0
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+11
-0
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+1
-0
No files found.
vllm/model_executor/model_loader/gguf_loader.py
View file @
b159c0a6
...
@@ -74,6 +74,17 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -74,6 +74,17 @@ class GGUFModelLoader(BaseModelLoader):
f
"model.layers.
{
idx
}
.mlp.experts.0.gate_proj.weight"
f
"model.layers.
{
idx
}
.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
\
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.up_proj.weight"
f
"model.layers.
{
idx
}
.mlp.experts.0.up_proj.weight"
if
model_type
in
(
"qwen2_moe"
,
"qwen3_moe"
):
model_type
=
model_type
.
replace
(
"_"
,
""
)
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for
idx
in
range
(
config
.
num_hidden_layers
):
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_down_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_gate_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.up_proj.weight"
arch
=
None
arch
=
None
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
b159c0a6
...
@@ -375,6 +375,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -375,6 +375,7 @@ class Qwen3MoeModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment