Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dc24bc8
Commit
4dc24bc8
authored
May 06, 2025
by
zhuwenwen
Browse files
support qwen3-moe nn layout
parent
15470ae4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
4 deletions
+61
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+3
-3
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+58
-1
No files found.
vllm/model_executor/model_loader/utils.py
View file @
4dc24bc8
...
@@ -89,9 +89,9 @@ def get_model_architecture(
...
@@ -89,9 +89,9 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'Q
wen2ForCausalLM'
,
'QWenLMHeadModel
'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'Qwen3ForCausalLM'
,
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'Q
WenLMHeadModel'
,
'Qwen2ForCausalLM
'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'Qwen3ForCausalLM'
,
'Qwen3MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MixtralForCausalLM'
,
'FalconForCausalLM'
,
'
BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM
'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
'
MedusaModel
'
,
'MLPSpeculatorPreTrainedModel'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
4dc24bc8
...
@@ -23,6 +23,8 @@
...
@@ -23,6 +23,8 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
os
import
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -55,6 +57,9 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
...
@@ -55,6 +57,9 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -343,6 +348,18 @@ class Qwen3MoeModel(nn.Module):
...
@@ -343,6 +348,18 @@ class Qwen3MoeModel(nn.Module):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -472,6 +489,46 @@ class Qwen3MoeModel(nn.Module):
...
@@ -472,6 +489,46 @@ class Qwen3MoeModel(nn.Module):
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"gate_up_proj.weight"
,
"down_proj.weight"
,
"mlp.gate.weight"
,
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"lm_head.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
# lay_qkv_words = ["self_attn.qkv_proj.weight"]
# qkv_words = "|".join(lay_qkv_words)
# lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
os
.
environ
[
'LM_NN'
]
=
'0'
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
return
loaded_params
return
loaded_params
...
@@ -525,4 +582,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -525,4 +582,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
self
,
self
,
skip_prefixes
=
([
"rotary_emb.inv_freq"
]),
skip_prefixes
=
([
"rotary_emb.inv_freq"
]),
)
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment