Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3eda4ec7
Unverified
Commit
3eda4ec7
authored
Jul 22, 2024
by
Simon Mo
Committed by
GitHub
Jul 22, 2024
Browse files
support ignore patterns in model loader (#6673)
parent
22fa2e35
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
10 deletions
+51
-10
vllm/config.py
vllm/config.py
+14
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+10
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+21
-8
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+6
-1
No files found.
vllm/config.py
View file @
3eda4ec7
...
...
@@ -599,12 +599,16 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
ignore_patterns: The list of patterns to ignore when loading the model.
Defaults to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
download_dir
:
Optional
[
str
]
=
None
model_loader_extra_config
:
Optional
[
Union
[
str
,
dict
]]
=
field
(
default_factory
=
dict
)
ignore_patterns
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
def
__post_init__
(
self
):
model_loader_extra_config
=
self
.
model_loader_extra_config
or
{}
...
...
@@ -613,6 +617,13 @@ class LoadConfig:
model_loader_extra_config
)
self
.
_verify_load_format
()
if
self
.
ignore_patterns
is
not
None
and
len
(
self
.
ignore_patterns
)
>
0
:
logger
.
info
(
"Ignoring the following patterns when downloading weights: %s"
,
self
.
ignore_patterns
)
else
:
self
.
ignore_patterns
=
[
"original/**/*"
]
def
_verify_load_format
(
self
)
->
None
:
if
not
isinstance
(
self
.
load_format
,
str
):
return
...
...
@@ -801,7 +812,9 @@ class SchedulerConfig:
# for higher throughput.
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
enable_chunked_prefill
:
logger
.
info
(
"Chunked prefill is enabled (EXPERIMENTAL)."
)
logger
.
info
(
"Chunked prefill is enabled with max_num_batched_tokens=%d."
,
max_num_batched_tokens
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
...
...
vllm/engine/arg_utils.py
View file @
3eda4ec7
...
...
@@ -95,6 +95,7 @@ class EngineArgs:
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
model_loader_extra_config
:
Optional
[
dict
]
=
None
ignore_patterns
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
preemption_mode
:
Optional
[
str
]
=
None
scheduler_delay_factor
:
float
=
0.0
...
...
@@ -619,6 +620,14 @@ class EngineArgs:
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.'
)
parser
.
add_argument
(
'--ignore-patterns'
,
action
=
"append"
,
type
=
str
,
default
=
[],
help
=
"The pattern(s) to ignore when loading the model."
"Default to 'original/**/*' to avoid repeated loading of llama's "
"checkpoints."
)
parser
.
add_argument
(
'--preemption-mode'
,
type
=
str
,
...
...
@@ -824,6 +833,7 @@ class EngineArgs:
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
)
prompt_adapter_config
=
PromptAdapterConfig
(
...
...
vllm/model_executor/model_loader/loader.py
View file @
3eda4ec7
...
...
@@ -161,6 +161,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
model_path
=
model
...
...
@@ -196,9 +197,13 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
)
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
hf_folder
=
model_name_or_path
...
...
@@ -489,9 +494,13 @@ class ShardedStateLoader(BaseModelLoader):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
)
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
...
...
@@ -663,8 +672,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
matching_files
=
fnmatch
.
filter
(
repo_files
,
pattern
)
if
matching_files
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
[
pattern
],
revision
)
model_name_or_path
,
self
.
load_config
.
download_dir
,
[
pattern
],
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
return
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
)),
pattern
raise
RuntimeError
(
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
3eda4ec7
...
...
@@ -6,7 +6,7 @@ import json
import
os
import
tempfile
from
collections
import
defaultdict
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
filelock
import
huggingface_hub.constants
...
...
@@ -189,6 +189,7 @@ def download_weights_from_hf(
cache_dir
:
Optional
[
str
],
allow_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
ignore_patterns
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
)
->
str
:
"""Download model weights from Hugging Face Hub.
...
...
@@ -200,6 +201,9 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns
will be ignored.
Returns:
str: The path to the downloaded model weights.
...
...
@@ -223,6 +227,7 @@ def download_weights_from_hf(
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
cache_dir
=
cache_dir
,
tqdm_class
=
DisabledTqdm
,
revision
=
revision
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment