Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
109e15a3
Unverified
Commit
109e15a3
authored
May 01, 2025
by
Jerry Zhang
Committed by
GitHub
May 01, 2025
Browse files
Add `pt_load_map_location` to allow loading to cuda (#16869)
Signed-off-by:
Jerry Zhang
<
jerryzh168@gmail.com
>
parent
f192ca90
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
74 additions
and
3 deletions
+74
-3
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+26
-0
tests/test_config.py
tests/test_config.py
+15
-1
vllm/config.py
vllm/config.py
+10
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+17
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+4
-1
No files found.
tests/quantization/test_torchao.py
View file @
109e15a3
...
...
@@ -3,6 +3,7 @@ import importlib.metadata
import
importlib.util
import
pytest
import
torch
DTYPE
=
[
"bfloat16"
]
...
...
@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
print
(
output
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
@
pytest
.
mark
.
parametrize
(
"pt_load_map_location"
,
[
"cuda:0"
,
# {"": "cuda"},
])
def
test_opt_125m_int4wo_model_loading_with_params
(
vllm_runner
,
pt_load_map_location
):
"""
Test loading roberta-base model with no lm_head.
"""
torch
.
_dynamo
.
reset
()
model_name
=
"jerryzh168/opt-125m-int4wo"
with
vllm_runner
(
model_name
=
model_name
,
quantization
=
"torchao"
,
dtype
=
"bfloat16"
,
pt_load_map_location
=
pt_load_map_location
)
as
llm
:
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
32
)
assert
output
print
(
output
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
tests/test_config.py
View file @
109e15a3
...
...
@@ -5,7 +5,8 @@ from typing import Literal, Union
import
pytest
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
config
,
get_field
from
vllm.config
import
(
LoadConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
,
config
,
get_field
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
...
...
@@ -410,3 +411,16 @@ def test_generation_config_loading():
override_generation_config
=
override_generation_config
)
assert
model_config
.
get_diff_sampling_param
()
==
override_generation_config
@
pytest
.
mark
.
parametrize
(
"pt_load_map_location"
,
[
"cuda"
,
{
""
:
"cuda"
},
])
def
test_load_config_pt_load_map_location
(
pt_load_map_location
):
load_config
=
LoadConfig
(
pt_load_map_location
=
pt_load_map_location
)
config
=
VllmConfig
(
load_config
=
load_config
)
assert
config
.
load_config
.
pt_load_map_location
==
pt_load_map_location
vllm/config.py
View file @
109e15a3
...
...
@@ -1564,6 +1564,16 @@ class LoadConfig:
use_tqdm_on_load
:
bool
=
True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
pt_load_map_location
:
Union
[
str
,
dict
[
str
,
str
]]
=
"cpu"
"""
pt_load_map_location: the map location for loading pytorch checkpoint, to
support loading checkpoints can only be loaded on certain devices like
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
mapping from different devices like from GPU 1 to GPU 0:
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
in dictionary needs to be double quoted for json parsing. For more details,
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
"""
def
compute_hash
(
self
)
->
str
:
"""
...
...
vllm/engine/arg_utils.py
View file @
109e15a3
...
...
@@ -64,6 +64,13 @@ def optional_type(
return
_optional_type
def
union_dict_and_str
(
val
:
str
)
->
Optional
[
Union
[
str
,
dict
[
str
,
str
]]]:
if
not
re
.
match
(
"^{.*}$"
,
val
):
return
str
(
val
)
else
:
return
optional_type
(
json
.
loads
)(
val
)
@
deprecated
(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
...
...
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs
[
name
][
"type"
]
=
human_readable_int
elif
contains_type
(
type_hints
,
float
):
kwargs
[
name
][
"type"
]
=
float
elif
contains_type
(
type_hints
,
dict
)
and
(
contains_type
(
type_hints
,
str
)
or
any
(
is_not_builtin
(
th
)
for
th
in
type_hints
)):
kwargs
[
name
][
"type"
]
=
union_dict_and_str
elif
contains_type
(
type_hints
,
dict
):
# Dict arguments will always be optional
kwargs
[
name
][
"type"
]
=
optional_type
(
json
.
loads
)
...
...
@@ -371,6 +382,7 @@ class EngineArgs:
reasoning_parser
:
str
=
DecodingConfig
.
reasoning_backend
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
pt_load_map_location
:
str
=
LoadConfig
.
pt_load_map_location
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
...
...
@@ -491,6 +503,8 @@ class EngineArgs:
type
=
str
,
default
=
None
,
help
=
'Name or path of the QLoRA adapter.'
)
load_group
.
add_argument
(
'--pt-load-map-location'
,
**
load_kwargs
[
"pt_load_map_location"
])
# Guided decoding arguments
guided_decoding_kwargs
=
get_kwargs
(
DecodingConfig
)
...
...
@@ -883,12 +897,14 @@ class EngineArgs:
if
self
.
quantization
==
"bitsandbytes"
:
self
.
load_format
=
"bitsandbytes"
return
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
use_tqdm_on_load
=
self
.
use_tqdm_on_load
,
pt_load_map_location
=
self
.
pt_load_map_location
,
)
def
create_speculative_config
(
...
...
@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
def
human_readable_int
(
value
):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
...
...
vllm/model_executor/model_loader/loader.py
View file @
109e15a3
...
...
@@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
pt_load_map_location
,
)
if
current_platform
.
is_tpu
():
...
...
@@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
iterator
=
pt_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
pt_load_map_location
,
)
for
org_name
,
param
in
iterator
:
# mapping weight names from transformers to vllm while preserving
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
109e15a3
...
...
@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
def
pt_weights_iterator
(
hf_weights_files
:
List
[
str
],
use_tqdm_on_load
:
bool
,
pt_load_map_location
:
Union
[
str
,
dict
[
str
,
str
]]
=
"cpu"
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model bin/pt files."""
for
bin_file
in
tqdm
(
...
...
@@ -510,7 +511,9 @@ def pt_weights_iterator(
disable
=
not
enable_tqdm
(
use_tqdm_on_load
),
bar_format
=
_BAR_FORMAT
,
):
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
,
weights_only
=
True
)
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
yield
from
state
.
items
()
del
state
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment