norm / vllm · Commits · eed74a55

Simplify weight loading logic (#2133)

Unverified commit eed74a55, authored Dec 17, 2023 by Roy, committed by GitHub Dec 16, 2023. Parent: 2acd76f3.

Showing 3 changed files with 33 additions and 37 deletions (+33, -37).
vllm/config.py                             +4   -9
vllm/model_executor/models/mixtral.py      +5   -1
vllm/model_executor/weight_utils.py        +24  -27
vllm/config.py

@@ -122,15 +122,10 @@ class ModelConfig:
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
 
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
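Note on the change: before this commit, ModelConfig silently rewrote load_format from "auto" to "safetensors" for Mixtral; now only an explicit "pt" request is rejected, and the fallback decision moves into the weight loader (see the mixtral.py and weight_utils.py diffs below). A minimal standalone sketch of the simplified check, assuming it can be isolated from ModelConfig.__init__ (verify_load_format and the SimpleNamespace stand-in for hf_config are illustrative names, not part of the diff):

# Sketch only; mirrors the post-commit validation in ModelConfig.
from types import SimpleNamespace

def verify_load_format(hf_config, load_format: str) -> str:
    architectures = getattr(hf_config, "architectures", [])
    if "MixtralForCausalLM" in architectures and load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead. ")
    return load_format

mixtral_cfg = SimpleNamespace(architectures=["MixtralForCausalLM"])
assert verify_load_format(mixtral_cfg, "auto") == "auto"  # "auto" passes through unchanged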
vllm/model_executor/models/mixtral.py

@@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
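The only behavioral change here is the new fall_back_to_pt=False argument: Mixtral's load_weights now tells hf_model_weights_iterator not to widen its search to *.pt checkpoints when no safetensors or .bin files are found, replacing the config-level rewrite removed above. A hedged usage sketch (the model id is an assumed example, and running this would download the full checkpoint):

# Usage sketch, not part of the diff.
from vllm.model_executor.weight_utils import hf_model_weights_iterator

for name, loaded_weight in hf_model_weights_iterator(
        "mistralai/Mixtral-8x7B-v0.1",  # assumed example model id
        cache_dir=None,
        load_format="auto",
        revision=None,
        fall_back_to_pt=False):
    if "rotary_emb.inv_freq" in name:
        continue  # same skip as in MixtralForCausalLM.load_weights
    print(name, tuple(loaded_weight.shape))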
vllm/model_executor/weight_utils.py

@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"
-                      ] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.

@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
         if not any(f.endswith(x) for x in blacklist)
     ]
 
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")

@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
 
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
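Net effect of the weight_utils.py changes: prepare_hf_model_weights now derives its glob patterns directly from load_format, appends *.pt only when the caller allows the fallback, and decides use_safetensors by which pattern matches files first, which is what lets the old recursive retry be deleted. A self-contained sketch of that selection step (select_patterns is a hypothetical helper name; the real code inlines this logic in prepare_hf_model_weights):

from typing import List, Tuple

def select_patterns(load_format: str,
                    fall_back_to_pt: bool) -> Tuple[bool, List[str]]:
    # Maps load_format to an ordered list of weight-file glob patterns.
    use_safetensors = False
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]  # tried last, only if the caller opts in
    # In the real function, use_safetensors later flips to True for "auto"
    # if "*.safetensors" is the first pattern that matches any files.
    return use_safetensors, allow_patterns

assert select_patterns("auto", True) == (False, ["*.safetensors", "*.bin", "*.pt"])
assert select_patterns("safetensors", False) == (True, ["*.safetensors"])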