Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
91977095
Unverified
Commit
91977095
authored
May 24, 2024
by
Robert Shaw
Committed by
GitHub
May 24, 2024
Browse files
[Bugfix] Fix Mistral v0.3 Weight Loading (#5005)
Co-authored-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
6a50f4ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
3 deletions
+79
-3
tests/models/test_mistral.py
tests/models/test_mistral.py
+1
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+15
-2
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+63
-1
No files found.
tests/models/test_mistral.py
View file @
91977095
...
@@ -8,6 +8,7 @@ from .utils import check_logprobs_close
...
@@ -8,6 +8,7 @@ from .utils import check_logprobs_close
MODELS
=
[
MODELS
=
[
"mistralai/Mistral-7B-Instruct-v0.1"
,
"mistralai/Mistral-7B-Instruct-v0.1"
,
"mistralai/Mistral-7B-Instruct-v0.3"
,
]
]
...
...
vllm/model_executor/model_loader/loader.py
View file @
91977095
...
@@ -23,7 +23,8 @@ from vllm.model_executor.model_loader.tensorizer import (
...
@@ -23,7 +23,8 @@ from vllm.model_executor.model_loader.tensorizer import (
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
set_default_torch_dtype
)
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
download_weights_from_hf
,
filter_files_not_needed_for_inference
,
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
...
@@ -188,7 +189,19 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -188,7 +189,19 @@ class DefaultModelLoader(BaseModelLoader):
use_safetensors
=
True
use_safetensors
=
True
break
break
if
not
use_safetensors
:
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
hf_weights_files
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
91977095
...
@@ -12,9 +12,10 @@ import filelock
...
@@ -12,9 +12,10 @@ import filelock
import
huggingface_hub.constants
import
huggingface_hub.constants
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
huggingface_hub
import
HfFileSystem
,
snapshot_download
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -218,6 +219,67 @@ def download_weights_from_hf(
...
@@ -218,6 +219,67 @@ def download_weights_from_hf(
return
hf_folder
return
hf_folder
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    A missing index file is not an error: only some models ship
    SAFE_WEIGHTS_INDEX_NAME, so both "not found" cases are logged at
    info level and swallowed.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=SAFE_WEIGHTS_INDEX_NAME,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
        #
        # LocalEntryNotFoundError subclasses EntryNotFoundError, so the
        # narrower exception must be handled first; otherwise the
        # "local cache" branch is unreachable and offline misses are
        # reported as remote misses.
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.",
                        SAFE_WEIGHTS_INDEX_NAME)
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
                                       hf_folder: str) -> List[str]:
    """Keep only the safetensors files referenced by the index file.

    Args:
        hf_weights_files: Candidate weight file paths. Entries are compared
            by exact string equality against
            ``os.path.join(hf_folder, <file named in the index>)``.
        hf_folder: Directory expected to contain SAFE_WEIGHTS_INDEX_NAME.

    Returns:
        ``hf_weights_files`` unchanged when no index file exists in
        ``hf_folder``; otherwise a new list, in the original order, holding
        only the paths that the index's ``weight_map`` points to.

    Raises:
        KeyError: If the index file has no ``"weight_map"`` entry.
    """
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to safetensors file holding that weight.
    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_file_name):
        return hf_weights_files

    # Read the weight_map (weight_name: safetensors file). JSON is decoded
    # as UTF-8 explicitly so the result does not depend on the locale.
    with open(index_file_name, encoding="utf-8") as index_file:
        weight_map = json.load(index_file)["weight_map"]

    # Many weights share one shard, so collapse the values into a set of
    # full paths for O(1) membership tests below.
    weight_files_in_index = {
        os.path.join(hf_folder, shard_file)
        for shard_file in weight_map.values()
    }

    # Filter out any files that are not found in the index file.
    return [f for f in hf_weights_files if f in weight_files_in_index]
def
filter_files_not_needed_for_inference
(
def
filter_files_not_needed_for_inference
(
hf_weights_files
:
List
[
str
])
->
List
[
str
]:
hf_weights_files
:
List
[
str
])
->
List
[
str
]:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment