Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30e75439
Unverified
Commit
30e75439
authored
May 16, 2024
by
Aurick Qiao
Committed by
GitHub
May 15, 2024
Browse files
[Core] Implement sharded state loader (#4690)
Co-authored-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
52f8107c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
351 additions
and
0 deletions
+351
-0
examples/save_sharded_state.py
examples/save_sharded_state.py
+75
-0
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+90
-0
vllm/config.py
vllm/config.py
+1
-0
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+11
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+148
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+14
-0
vllm/worker/worker.py
vllm/worker/worker.py
+12
-0
No files found.
examples/save_sharded_state.py
0 → 100644
View file @
30e75439
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py
\
--model /path/to/load
\
--quantization deepspeedfp
\
--tensor-parallel-size 8
\
--output /path/to/save
Then, the model can be loaded with
llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import
argparse
import
dataclasses
import
os
import
shutil
from
pathlib
import
Path
from
vllm
import
LLM
,
EngineArgs
parser
=
argparse
.
ArgumentParser
()
EngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
"--output"
,
"-o"
,
required
=
True
,
type
=
str
,
help
=
"path to output checkpoint"
)
parser
.
add_argument
(
"--file-pattern"
,
type
=
str
,
help
=
"string pattern of saved filenames"
)
parser
.
add_argument
(
"--max-file-size"
,
type
=
str
,
default
=
5
*
1024
**
3
,
help
=
"max size (in bytes) of each safetensors file"
)
def
main
(
args
):
engine_args
=
EngineArgs
.
from_cli_args
(
args
)
if
engine_args
.
enable_lora
:
raise
ValueError
(
"Saving with enable_lora=True is not supported!"
)
model_path
=
engine_args
.
model
if
not
Path
(
model_path
).
is_dir
():
raise
ValueError
(
"model path must be a local directory"
)
# Create LLM instance from arguments
llm
=
LLM
(
**
dataclasses
.
asdict
(
engine_args
))
# Prepare output directory
Path
(
args
.
output
).
mkdir
(
exist_ok
=
True
)
# Dump worker states to output directory
model_executor
=
llm
.
llm_engine
.
model_executor
model_executor
.
save_sharded_state
(
path
=
args
.
output
,
pattern
=
args
.
file_pattern
,
max_size
=
args
.
max_file_size
)
# Copy metadata files to output directory
for
file
in
os
.
listdir
(
model_path
):
if
os
.
path
.
splitext
(
file
)[
1
]
not
in
(
".bin"
,
".pt"
,
".safetensors"
):
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_path
,
file
)):
shutil
.
copytree
(
os
.
path
.
join
(
model_path
,
file
),
os
.
path
.
join
(
args
.
output
,
file
))
else
:
shutil
.
copy
(
os
.
path
.
join
(
model_path
,
file
),
args
.
output
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
main
(
args
)
tests/test_sharded_state_loader.py
0 → 100644
View file @
30e75439
import
os
import
shutil
from
tempfile
import
TemporaryDirectory
import
pytest
import
torch
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
,
SamplingParams
from
vllm.model_executor.model_loader.loader
import
ShardedStateLoader
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
seed
=
0
,
max_tokens
=
256
,
ignore_eos
=
True
,
)
def
test_filter_subtensors
():
state_dict
=
{
"a"
:
torch
.
empty
(
2
),
"b"
:
torch
.
empty
((
2
,
4
)),
"c"
:
torch
.
empty
((
2
,
4
,
8
)),
}
state_dict
.
update
({
"x"
:
state_dict
[
"b"
],
"y"
:
state_dict
[
"c"
][
1
,
2
,
:],
"z"
:
state_dict
[
"c"
][
1
,
:,
4
],
})
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
for
key
,
tensor
in
filtered_state_dict
.
items
():
assert
tensor
.
equal
(
state_dict
[
key
])
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
def
test_sharded_state_loader
(
enable_lora
):
weights_patterns
=
(
"*.bin"
,
"*.pt"
,
"*.safetensors"
)
with
TemporaryDirectory
()
as
cache_dir
,
TemporaryDirectory
()
as
output_dir
:
input_dir
=
snapshot_download
(
"meta-llama/Llama-2-7b-hf"
,
cache_dir
=
cache_dir
)
llm
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
gpu_memory_utilization
=
0.3
,
)
# Dump worker states to output directory
model_executor
=
llm
.
llm_engine
.
model_executor
model_executor
.
save_sharded_state
(
path
=
output_dir
)
# Copy metadata files to output directory
for
file
in
os
.
listdir
(
input_dir
):
if
not
any
(
file
.
endswith
(
ext
)
for
ext
in
weights_patterns
):
shutil
.
copy
(
f
"
{
input_dir
}
/
{
file
}
"
,
output_dir
)
del
llm
.
llm_engine
.
model_executor
llm_before
=
LLM
(
model
=
input_dir
,
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
)
gen_before
=
llm_before
.
generate
(
prompts
,
sampling_params
)
out_before
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_before
]
del
llm_before
.
llm_engine
.
model_executor
llm_after
=
LLM
(
model
=
output_dir
,
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
gpu_memory_utilization
=
0.3
,
load_format
=
"sharded_state"
,
)
gen_after
=
llm_after
.
generate
(
prompts
,
sampling_params
)
out_after
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_after
]
del
llm_after
.
llm_engine
.
model_executor
assert
out_before
==
out_after
vllm/config.py
View file @
30e75439
...
@@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
...
@@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE
=
"npcache"
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
@
dataclass
@
dataclass
...
...
vllm/executor/distributed_gpu_executor.py
View file @
30e75439
...
@@ -77,6 +77,17 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -77,6 +77,17 @@ class DistributedGPUExecutor(GPUExecutor):
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
return
self
.
_run_workers
(
"list_loras"
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
_run_workers
(
"save_sharded_state"
,
path
=
path
,
pattern
=
pattern
,
max_size
=
max_size
)
@
abstractmethod
@
abstractmethod
def
_run_workers
(
def
_run_workers
(
self
,
self
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
30e75439
# ruff: noqa: SIM117
# ruff: noqa: SIM117
import
collections
import
copy
import
copy
import
glob
import
glob
import
os
import
os
...
@@ -366,6 +367,150 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -366,6 +367,150 @@ class TensorizerLoader(BaseModelLoader):
cache_config
)
cache_config
)
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_states.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
=
collections
.
defaultdict
(
list
)
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
,
cache_config
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
model_config
.
model
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
path
in
filepaths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
part_idx
=
0
total_size
=
0
state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
model
.
state_dict
())
state_dict_part
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
tensor
in
state_dict
.
items
():
param_size
=
tensor
.
nelement
()
*
tensor
.
element_size
()
if
max_size
is
not
None
and
total_size
+
param_size
>
max_size
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
part_idx
+=
1
total_size
=
0
state_dict_part
=
{}
state_dict_part
[
key
]
=
tensor
total_size
+=
param_size
if
len
(
state_dict_part
)
>
0
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
"""Get a model loader based on the load format."""
...
@@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
...
@@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/worker/model_runner.py
View file @
30e75439
...
@@ -182,6 +182,20 @@ class ModelRunner:
...
@@ -182,6 +182,20 @@ class ModelRunner:
"but the KV cache data type is not FP8. "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
"KV cache scaling factors will not be used."
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
vllm.model_executor.model_loader.loader
import
ShardedStateLoader
ShardedStateLoader
.
save_model
(
self
.
model
,
path
,
pattern
=
pattern
,
max_size
=
max_size
,
)
def
get_max_block_per_batch
(
self
)
->
int
:
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
block_size
=
self
.
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
...
...
vllm/worker/worker.py
View file @
30e75439
...
@@ -119,6 +119,18 @@ class Worker(WorkerBase):
...
@@ -119,6 +119,18 @@ class Worker(WorkerBase):
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
model_runner
.
save_sharded_state
(
path
,
pattern
=
pattern
,
max_size
=
max_size
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Profiles the peak memory usage of the model to determine how many
"""Profiles the peak memory usage of the model to determine how many
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment