xdb4_94051 / vllm · Commit eed74a55

Simplify weight loading logic (#2133)

Unverified commit eed74a55, authored Dec 17, 2023 by Roy, committed by GitHub Dec 16, 2023.
Parent: 2acd76f3

Showing 3 changed files with 33 additions and 37 deletions:

- vllm/config.py (+4, -9)
- vllm/model_executor/models/mixtral.py (+5, -1)
- vllm/model_executor/weight_utils.py (+24, -27)
vllm/config.py
```diff
@@ -122,15 +122,10 @@ class ModelConfig:
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
```
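The check in ModelConfig now rejects the 'pt' format for Mixtral outright, instead of also rewriting `load_format` from "auto" to "safetensors"; the "auto" case is handled by the new pattern-priority logic in `prepare_hf_model_weights` (see the weight_utils.py hunks below). A minimal sketch of the simplified check, assuming a hypothetical standalone function (`verify_load_format` is an invented name, not the vLLM API):

```python
# Illustrative standalone version of the simplified check; the function
# name and signature are invented for this sketch.
def verify_load_format(load_format: str, architectures: list) -> str:
    if "MixtralForCausalLM" in architectures and load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead. ")
    # "auto" is no longer rewritten to "safetensors" here; safetensors
    # files simply win on pattern priority during weight discovery.
    return load_format

# "auto" now passes through unchanged, where the old code forced "safetensors":
assert verify_load_format("auto", ["MixtralForCausalLM"]) == "auto"
```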
vllm/model_executor/models/mixtral.py
```diff
@@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
```
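This call site is the reason the new `fall_back_to_pt` flag exists: by passing `fall_back_to_pt=False`, the Mixtral loader guarantees that even under `load_format="auto"` the `"*.pt"` pattern is never appended to the glob list, so unsupported pt checkpoints cannot be silently picked up. This replaces the old config-level rewrite of `load_format`; the pattern-resolution sketch after the first weight_utils.py hunk below shows the effect.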
vllm/model_executor/weight_utils.py
```diff
@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
```
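The rewritten `prepare_hf_model_weights` derives its glob patterns directly from `load_format`, then appends `"*.pt"` as a last-resort pattern when `fall_back_to_pt` is set. A self-contained sketch of that mapping (`resolve_patterns` is an invented helper name for illustration, not part of vLLM):

```python
from typing import List, Tuple

# Illustrative helper mirroring the pattern resolution in the hunk above.
def resolve_patterns(load_format: str,
                     fall_back_to_pt: bool) -> Tuple[bool, List[str]]:
    use_safetensors = False
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]
    return use_safetensors, allow_patterns

# "auto" prefers safetensors, then .bin, with .pt as the fallback:
assert resolve_patterns("auto", True) == (False,
                                          ["*.safetensors", "*.bin", "*.pt"])
# Mixtral passes fall_back_to_pt=False, so .pt never enters the list:
assert resolve_patterns("auto", False) == (False, ["*.safetensors", "*.bin"])
```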
```diff
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
```
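Pattern priority is enforced by a first-match-wins loop: patterns are tried in order, the search stops at the first pattern that matches any file, and `use_safetensors` flips to True only when that winning pattern is `"*.safetensors"`. A sketch of the same loop in isolation (the `find_weight_files` wrapper is invented for illustration):

```python
import glob
import os
from typing import List, Tuple

# Invented wrapper around the first-match-wins loop from the hunk above.
def find_weight_files(hf_folder: str,
                      allow_patterns: List[str]) -> Tuple[List[str], bool]:
    use_safetensors = False
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            # Stop at the highest-priority pattern that matched anything,
            # so lower-priority formats are never mixed in.
            if pattern == "*.safetensors":
                use_safetensors = True
            break
    return hf_weights_files, use_safetensors
```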
```diff
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
         if not any(f.endswith(x) for x in blacklist)
     ]
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
-
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
```
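With the `"*.pt"` fallback now expressed as an extra glob pattern inside `allow_patterns`, the recursive retry with `use_safetensors=False` is dead code and is deleted; an empty result list now falls straight through to the RuntimeError.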
```diff
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
 
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
```
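`hf_model_weights_iterator` no longer re-derives `use_safetensors`, `use_np_cache`, and `fall_back_to_pt` from `load_format`; it forwards `load_format` and the new `fall_back_to_pt` argument to `prepare_hf_model_weights` and tests `load_format == "npcache"` directly. A hedged usage sketch (the model name is an arbitrary example; the keyword arguments match the post-commit signature shown above):

```python
from vllm.model_executor.weight_utils import hf_model_weights_iterator

# Iterate over checkpoint tensors; "facebook/opt-125m" is just an example.
for name, tensor in hf_model_weights_iterator(
        "facebook/opt-125m",
        cache_dir=None,
        load_format="auto",      # prefer safetensors, then .bin, then .pt
        revision=None,
        fall_back_to_pt=True):   # default; the Mixtral loader passes False
    print(name, tuple(tensor.shape))
```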