Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30e75439
Unverified
Commit
30e75439
authored
May 16, 2024
by
Aurick Qiao
Committed by
GitHub
May 15, 2024
Browse files
[Core] Implement sharded state loader (#4690)
Co-authored-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
52f8107c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
351 additions
and
0 deletions
+351
-0
examples/save_sharded_state.py
examples/save_sharded_state.py
+75
-0
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+90
-0
vllm/config.py
vllm/config.py
+1
-0
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+11
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+148
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+14
-0
vllm/worker/worker.py
vllm/worker/worker.py
+12
-0
No files found.
examples/save_sharded_state.py
0 → 100644
View file @
30e75439
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py
\
--model /path/to/load
\
--quantization deepspeedfp
\
--tensor-parallel-size 8
\
--output /path/to/save
Then, the model can be loaded with
llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import
argparse
import
dataclasses
import
os
import
shutil
from
pathlib
import
Path
from
vllm
import
LLM
,
EngineArgs
parser
=
argparse
.
ArgumentParser
()
EngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
"--output"
,
"-o"
,
required
=
True
,
type
=
str
,
help
=
"path to output checkpoint"
)
parser
.
add_argument
(
"--file-pattern"
,
type
=
str
,
help
=
"string pattern of saved filenames"
)
parser
.
add_argument
(
"--max-file-size"
,
type
=
str
,
default
=
5
*
1024
**
3
,
help
=
"max size (in bytes) of each safetensors file"
)
def
main
(
args
):
engine_args
=
EngineArgs
.
from_cli_args
(
args
)
if
engine_args
.
enable_lora
:
raise
ValueError
(
"Saving with enable_lora=True is not supported!"
)
model_path
=
engine_args
.
model
if
not
Path
(
model_path
).
is_dir
():
raise
ValueError
(
"model path must be a local directory"
)
# Create LLM instance from arguments
llm
=
LLM
(
**
dataclasses
.
asdict
(
engine_args
))
# Prepare output directory
Path
(
args
.
output
).
mkdir
(
exist_ok
=
True
)
# Dump worker states to output directory
model_executor
=
llm
.
llm_engine
.
model_executor
model_executor
.
save_sharded_state
(
path
=
args
.
output
,
pattern
=
args
.
file_pattern
,
max_size
=
args
.
max_file_size
)
# Copy metadata files to output directory
for
file
in
os
.
listdir
(
model_path
):
if
os
.
path
.
splitext
(
file
)[
1
]
not
in
(
".bin"
,
".pt"
,
".safetensors"
):
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_path
,
file
)):
shutil
.
copytree
(
os
.
path
.
join
(
model_path
,
file
),
os
.
path
.
join
(
args
.
output
,
file
))
else
:
shutil
.
copy
(
os
.
path
.
join
(
model_path
,
file
),
args
.
output
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
main
(
args
)
tests/test_sharded_state_loader.py
0 → 100644
View file @
30e75439
import
os
import
shutil
from
tempfile
import
TemporaryDirectory
import
pytest
import
torch
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
,
SamplingParams
from
vllm.model_executor.model_loader.loader
import
ShardedStateLoader
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
,
max_tokens
=
256
,
ignore_eos
=
True
,
)
def
test_filter_subtensors
():
state_dict
=
{
"a"
:
torch
.
empty
(
2
),
"b"
:
torch
.
empty
((
2
,
4
)),
"c"
:
torch
.
empty
((
2
,
4
,
8
)),
}
state_dict
.
update
({
"x"
:
state_dict
[
"b"
],
"y"
:
state_dict
[
"c"
][
1
,
2
,
:],
"z"
:
state_dict
[
"c"
][
1
,
:,
4
],
})
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
for
key
,
tensor
in
filtered_state_dict
.
items
():
assert
tensor
.
equal
(
state_dict
[
key
])
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
def
test_sharded_state_loader
(
enable_lora
):
weights_patterns
=
(
"*.bin"
,
"*.pt"
,
"*.safetensors"
)
with
TemporaryDirectory
()
as
cache_dir
,
TemporaryDirectory
()
as
output_dir
:
input_dir
=
snapshot_download
(
"meta-llama/Llama-2-7b-hf"
,
cache_dir
=
cache_dir
)
llm
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
gpu_memory_utilization
=
0.3
,
)
# Dump worker states to output directory
model_executor
=
llm
.
llm_engine
.
model_executor
model_executor
.
save_sharded_state
(
path
=
output_dir
)
# Copy metadata files to output directory
for
file
in
os
.
listdir
(
input_dir
):
if
not
any
(
file
.
endswith
(
ext
)
for
ext
in
weights_patterns
):
shutil
.
copy
(
f
"
{
input_dir
}
/
{
file
}
"
,
output_dir
)
del
llm
.
llm_engine
.
model_executor
llm_before
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
)
gen_before
=
llm_before
.
generate
(
prompts
,
sampling_params
)
out_before
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_before
]
del
llm_before
.
llm_engine
.
model_executor
llm_after
=
LLM
(
model
=
output_dir
,
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
load_format
=
"sharded_state"
,
)
gen_after
=
llm_after
.
generate
(
prompts
,
sampling_params
)
out_after
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_after
]
del
llm_after
.
llm_engine
.
model_executor
assert
out_before
==
out_after
vllm/config.py
View file @
30e75439
...
...
@@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
@
dataclass
...
...
vllm/executor/distributed_gpu_executor.py
View file @
30e75439
...
...
@@ -77,6 +77,17 @@ class DistributedGPUExecutor(GPUExecutor):
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
_run_workers
(
"save_sharded_state"
,
path
=
path
,
pattern
=
pattern
,
max_size
=
max_size
)
@
abstractmethod
def
_run_workers
(
self
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
30e75439
# ruff: noqa: SIM117
import
collections
import
copy
import
glob
import
os
...
...
@@ -366,6 +367,150 @@ class TensorizerLoader(BaseModelLoader):
cache_config
)
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_states.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
=
collections
.
defaultdict
(
list
)
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
,
cache_config
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
model_config
.
model
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
path
in
filepaths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
part_idx
=
0
total_size
=
0
state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
model
.
state_dict
())
state_dict_part
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
tensor
in
state_dict
.
items
():
param_size
=
tensor
.
nelement
()
*
tensor
.
element_size
()
if
max_size
is
not
None
and
total_size
+
param_size
>
max_size
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
part_idx
+=
1
total_size
=
0
state_dict_part
=
{}
state_dict_part
[
key
]
=
tensor
total_size
+=
param_size
if
len
(
state_dict_part
)
>
0
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
...
...
@@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/worker/model_runner.py
View file @
30e75439
...
...
@@ -182,6 +182,20 @@ class ModelRunner:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
vllm.model_executor.model_loader.loader
import
ShardedStateLoader
ShardedStateLoader
.
save_model
(
self
.
model
,
path
,
pattern
=
pattern
,
max_size
=
max_size
,
)
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
...
...
vllm/worker/worker.py
View file @
30e75439
...
...
@@ -119,6 +119,18 @@ class Worker(WorkerBase):
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
model_runner
.
save_sharded_state
(
path
,
pattern
=
pattern
,
max_size
=
max_size
,
)
@
torch
.
inference_mode
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Profiles the peak memory usage of the model to determine how many
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment