Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db94f061
Commit
db94f061
authored
Apr 24, 2025
by
zhuwenwen
Browse files
fix llama and qwen layout
parent
14f46a65
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
30 deletions
+47
-30
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+29
-26
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+8
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+5
-2
No files found.
vllm/model_executor/model_loader/utils.py
View file @
db94f061
...
@@ -33,45 +33,52 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -33,45 +33,52 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
is_transformers_impl_compatible
(
def
is_transformers_impl_compatible
(
arch
:
str
,
arch
:
str
,
module
:
Optional
[
transformers
.
PreTrainedModel
]
=
None
)
->
bool
:
module
:
Optional
[
"
transformers.PreTrainedModel
"
]
=
None
)
->
bool
:
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
mod
=
module
or
getattr
(
transformers
,
arch
,
None
)
if
mod
is
None
:
if
mod
is
None
:
return
False
return
False
if
hasattr
(
mod
,
"supports_backend"
):
return
mod
.
is_backend_compatible
()
return
mod
.
is_backend_compatible
()
else
:
return
mod
.
_supports_flex_attn
def
resolve_transformers_
fallback
(
model_config
:
ModelConfig
,
def
resolve_transformers_
arch
(
model_config
:
ModelConfig
,
architectures
:
list
[
str
]):
architectures
:
list
[
str
]):
for
i
,
arch
in
enumerate
(
architectures
):
for
i
,
arch
in
enumerate
(
architectures
):
if
arch
==
"Transformers
Model
"
:
if
arch
==
"Transformers
ForCausalLM
"
:
continue
continue
custom_module
=
None
auto_map
:
dict
[
str
,
str
]
=
getattr
(
model_config
.
hf_config
,
"auto_map"
,
auto_map
=
getattr
(
model_config
.
hf_config
,
"auto_map"
,
None
)
None
)
or
dict
()
if
auto_map
is
not
None
and
"AutoModel"
in
auto_map
:
# Make sure that config class is always initialized before model class,
custom_module
=
get_class_from_dynamic_module
(
# otherwise the model class won't be able to access the config class,
model_config
.
hf_config
.
auto_map
[
"AutoModel"
],
# the expected auto_map should have correct order like:
model_config
.
model
)
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules
=
{
name
:
get_class_from_dynamic_module
(
module
,
model_config
.
model
)
for
name
,
module
in
sorted
(
auto_map
.
items
(),
key
=
lambda
x
:
x
[
0
])
}
custom_model_module
=
auto_modules
.
get
(
"AutoModel"
)
# TODO(Isotr0py): Further clean up these raises.
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
:
if
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
:
if
not
is_transformers_impl_compatible
(
arch
,
custom_module
):
if
not
is_transformers_impl_compatible
(
arch
,
custom_
model_
module
):
raise
ValueError
(
raise
ValueError
(
f
"The Transformers implementation of
{
arch
}
is not "
f
"The Transformers implementation of
{
arch
}
is not "
"compatible with vLLM."
)
"compatible with vLLM."
)
architectures
[
i
]
=
"Transformers
Model
"
architectures
[
i
]
=
"Transformers
ForCausalLM
"
if
model_config
.
model_impl
==
ModelImpl
.
AUTO
:
if
model_config
.
model_impl
==
ModelImpl
.
AUTO
:
if
not
is_transformers_impl_compatible
(
arch
,
custom_module
):
if
not
is_transformers_impl_compatible
(
arch
,
custom_
model_
module
):
raise
ValueError
(
raise
ValueError
(
f
"
{
arch
}
has no vLLM implementation and the Transformers "
f
"
{
arch
}
has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM."
)
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0."
)
logger
.
warning
(
logger
.
warning
(
"%s has no vLLM implementation, falling back to Transformers "
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"implementation. Some features may not be supported and "
"performance may not be optimal."
,
arch
)
"performance may not be optimal."
,
arch
)
architectures
[
i
]
=
"Transformers
Model
"
architectures
[
i
]
=
"Transformers
ForCausalLM
"
return
architectures
return
architectures
...
@@ -111,9 +118,6 @@ def get_model_architecture(
...
@@ -111,9 +118,6 @@ def get_model_architecture(
os
.
environ
[
'AWQ_PAD'
]
=
'1'
os
.
environ
[
'AWQ_PAD'
]
=
'1'
else
:
else
:
os
.
environ
[
'AWQ_PAD'
]
=
'0'
os
.
environ
[
'AWQ_PAD'
]
=
'0'
else
:
if
os
.
getenv
(
'LLAMA_NN'
)
==
'1'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
else
:
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
...
@@ -137,8 +141,7 @@ def get_model_architecture(
...
@@ -137,8 +141,7 @@ def get_model_architecture(
for
arch
in
architectures
)
for
arch
in
architectures
)
if
(
not
is_vllm_supported
if
(
not
is_vllm_supported
or
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
):
or
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
):
architectures
=
resolve_transformers_fallback
(
model_config
,
architectures
=
resolve_transformers_arch
(
model_config
,
architectures
)
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
if
model_config
.
task
==
"embed"
:
if
model_config
.
task
==
"embed"
:
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
db94f061
...
@@ -407,6 +407,11 @@ def safetensors_weights_iterator(
...
@@ -407,6 +407,11 @@ def safetensors_weights_iterator(
hf_weights_files
:
List
[
str
]
hf_weights_files
:
List
[
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files."""
total_count
=
0
for
st_file
in
hf_weights_files
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
total_count
+=
len
(
f
.
keys
())
current_count
=
0
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
enable_tqdm
=
not
torch
.
distributed
.
is_initialized
(
)
or
torch
.
distributed
.
get_rank
()
==
0
)
or
torch
.
distributed
.
get_rank
()
==
0
for
st_file
in
tqdm
(
for
st_file
in
tqdm
(
...
@@ -417,7 +422,10 @@ def safetensors_weights_iterator(
...
@@ -417,7 +422,10 @@ def safetensors_weights_iterator(
):
):
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
for
name
in
f
.
keys
():
# noqa: SIM118
current_count
+=
1
param
=
f
.
get_tensor
(
name
)
param
=
f
.
get_tensor
(
name
)
param
.
current_count
=
current_count
param
.
total_count
=
total_count
yield
name
,
param
yield
name
,
param
...
...
vllm/model_executor/models/llama.py
View file @
db94f061
...
@@ -414,6 +414,8 @@ class LlamaModel(nn.Module):
...
@@ -414,6 +414,8 @@ class LlamaModel(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
current_count
=
loaded_weight
.
current_count
total_count
=
loaded_weight
.
total_count
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
(
"rotary_emb.cos_cached"
in
name
if
(
"rotary_emb.cos_cached"
in
name
...
@@ -466,7 +468,7 @@ class LlamaModel(nn.Module):
...
@@ -466,7 +468,7 @@ class LlamaModel(nn.Module):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
and
current_count
==
total_count
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"self_attn.o_proj.weight"
,
...
@@ -479,7 +481,8 @@ class LlamaModel(nn.Module):
...
@@ -479,7 +481,8 @@ class LlamaModel(nn.Module):
# qkv_words = "|".join(lay_qkv_words)
# qkv_words = "|".join(lay_qkv_words)
# for layername, weight in params_dict.items():
# for layername, weight in params_dict.items():
for
layername
in
loaded_params
:
# for layername in loaded_params:
for
layername
in
params_dict
.
keys
():
weight
=
params_dict
[
layername
]
weight
=
params_dict
[
layername
]
if
"lm_head.weight"
in
layername
and
weight
.
shape
[
1
]
>=
4096
:
if
"lm_head.weight"
in
layername
and
weight
.
shape
[
1
]
>=
4096
:
lay_key_words
.
append
(
"lm_head.weight"
)
lay_key_words
.
append
(
"lm_head.weight"
)
...
...
vllm/model_executor/models/qwen2.py
View file @
db94f061
...
@@ -398,6 +398,8 @@ class Qwen2Model(nn.Module):
...
@@ -398,6 +398,8 @@ class Qwen2Model(nn.Module):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
current_count
=
loaded_weight
.
current_count
total_count
=
loaded_weight
.
total_count
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
(
self
.
quant_config
is
not
None
and
if
(
self
.
quant_config
is
not
None
and
...
@@ -440,7 +442,7 @@ class Qwen2Model(nn.Module):
...
@@ -440,7 +442,7 @@ class Qwen2Model(nn.Module):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
and
current_count
==
total_count
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.qkv_proj.weight"
,
"self_attn.o_proj.weight"
,
"self_attn.o_proj.weight"
,
...
@@ -456,7 +458,8 @@ class Qwen2Model(nn.Module):
...
@@ -456,7 +458,8 @@ class Qwen2Model(nn.Module):
# qkv_bias_words = "|".join(lay_qkv_bias_words)
# qkv_bias_words = "|".join(lay_qkv_bias_words)
# for layername, weight in params_dict.items():
# for layername, weight in params_dict.items():
for
layername
in
loaded_params
:
# for layername in loaded_params:
for
layername
in
params_dict
.
keys
():
weight
=
params_dict
[
layername
]
weight
=
params_dict
[
layername
]
if
"lm_head.weight"
in
layername
and
weight
.
shape
[
1
]
>=
3584
:
if
"lm_head.weight"
in
layername
and
weight
.
shape
[
1
]
>=
3584
:
lay_key_words
.
append
(
"lm_head.weight"
)
lay_key_words
.
append
(
"lm_head.weight"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment