norm/vllm - commit 4bb6b671 (Unverified)
Authored Nov 21, 2023 by boydfd; committed by GitHub on Nov 20, 2023

fix RAM OOM when loading large models in tensor parallel mode. (#1395)

Co-authored-by: ran_lin <rlin@thoughtworks.com>

Parent: 819b18e7
Changes: 4 changed files with 52 additions and 7 deletions (+52, -7)
vllm/config.py             +2  -0
vllm/engine/arg_utils.py   +9  -1
vllm/engine/llm_engine.py  +39 -6
vllm/worker/worker.py      +2  -0
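In practice the change adds one knob, max_parallel_loading_workers, which limits how many tensor-parallel workers load model weights at the same time. A minimal offline-inference sketch, assuming the vllm.LLM wrapper of this version forwards extra keyword arguments to EngineArgs (the model name is illustrative, not part of this commit):

    from vllm import LLM

    # Hypothetical setup: 4-way tensor parallelism, but at most 2 workers
    # load their weights into host RAM at any one time.
    llm = LLM(
        model="meta-llama/Llama-2-70b-hf",
        tensor_parallel_size=4,
        max_parallel_loading_workers=2,
    )

Leaving the new argument at its default (None) keeps the previous behaviour: every worker loads its weights concurrently, which is what could exhaust host RAM for large checkpoints.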
vllm/config.py

@@ -285,10 +285,12 @@ class ParallelConfig:
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         worker_use_ray: bool,
+        max_parallel_loading_workers: Optional[int] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
+        self.max_parallel_loading_workers = max_parallel_loading_workers
 
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
vllm/engine/arg_utils.py

@@ -22,6 +22,7 @@ class EngineArgs:
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90

@@ -128,6 +129,12 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
+        parser.add_argument(
+            '--max-parallel-loading-workers',
+            type=int,
+            help='load model sequentially in multiple batches, '
+            'to avoid RAM OOM when using tensor '
+            'parallel and large models')
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,

@@ -195,7 +202,8 @@ class EngineArgs:
                                    getattr(model_config.hf_config,
                                            'sliding_window', None))
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.worker_use_ray)
+                                         self.worker_use_ray,
+                                         self.max_parallel_loading_workers)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
                                            model_config.max_model_len,
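Because the option is registered in EngineArgs.add_cli_args, any server entrypoint whose argument parser is built from EngineArgs exposes it on the command line. A hedged launch sketch (the entrypoint name, model, and values are illustrative, not part of this diff):

    python -m vllm.entrypoints.api_server \
        --model meta-llama/Llama-2-70b-hf \
        --tensor-parallel-size 4 \
        --max-parallel-loading-workers 2

With 4-way tensor parallelism this stages weight loading so that at most two workers read the checkpoint into host RAM at once.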
vllm/engine/llm_engine.py

@@ -143,6 +143,12 @@ class LLMEngine:
             "init_model",
             get_all_outputs=True,
         )
+        self._run_workers(
+            "load_model",
+            get_all_outputs=True,
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):

@@ -182,6 +188,12 @@ class LLMEngine:
             "init_model",
             get_all_outputs=True,
         )
+        self._run_workers(
+            "load_model",
+            get_all_outputs=True,
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)

@@ -682,16 +694,15 @@ class LLMEngine:
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
 
-    def _run_workers(
+    def _run_workers_in_batch(
         self,
+        workers,
         method: str,
         *args,
-        get_all_outputs: bool = False,
         **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
+    ):
         all_outputs = []
-        for worker in self.workers:
+        for worker in workers:
             if self.parallel_config.worker_use_ray:
                 executor = partial(worker.execute_method.remote, method)
             else:

@@ -699,9 +710,31 @@ class LLMEngine:
             output = executor(*args, **kwargs)
             all_outputs.append(output)
 
         if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)
-
+        return all_outputs
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        get_all_outputs: bool = False,
+        max_concurrent_workers: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        all_outputs = []
+        if max_concurrent_workers:
+            work_groups = [
+                self.workers[i:i + max_concurrent_workers]
+                for i in range(0, len(self.workers), max_concurrent_workers)
+            ]
+        else:
+            work_groups = [self.workers]
+
+        for workers in work_groups:
+            all_outputs.extend(
+                self._run_workers_in_batch(workers, method, *args, **kwargs))
 
         if get_all_outputs:
             return all_outputs
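The heart of the fix is the split between _run_workers and _run_workers_in_batch: when max_concurrent_workers is set, the worker list is sliced into groups of that size, and because _run_workers_in_batch gathers each group's Ray futures with ray.get before returning, one group finishes loading before the next one starts. A standalone sketch of the same batching pattern (names here are illustrative, not vLLM APIs):

    from typing import Callable, List, Optional, TypeVar

    T = TypeVar("T")
    R = TypeVar("R")

    def run_in_batches(items: List[T],
                       fn: Callable[[T], R],
                       max_concurrent: Optional[int] = None) -> List[R]:
        # No limit: a single group, i.e. the old "all workers at once" path.
        if max_concurrent:
            groups = [items[i:i + max_concurrent]
                      for i in range(0, len(items), max_concurrent)]
        else:
            groups = [items]
        outputs: List[R] = []
        for group in groups:
            # In the engine, each group's remote calls are awaited here,
            # so the group completes before the next one begins.
            outputs.extend(fn(item) for item in group)
        return outputs

Applied to the "load_model" call, this bounds how many workers hold model weights in host RAM at the same time during loading to max_parallel_loading_workers.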
vllm/worker/worker.py

@@ -67,6 +67,8 @@ class Worker:
         # Initialize the model.
         set_random_seed(self.model_config.seed)
+
+    def load_model(self):
         self.model = get_model(self.model_config)
 
     @torch.inference_mode()