Commit 9925c179 (Unverified)
Authored Jul 19, 2023 by Antoni Baum; committed by GitHub on Jul 19, 2023

Ray placement group support (#397)
Parent: 8c4b2592
Showing 5 changed files with 185 additions and 114 deletions.
Files changed:
  requirements.txt                  (+1, -1)
  vllm/engine/async_llm_engine.py   (+2, -2)
  vllm/engine/llm_engine.py         (+78, -27)
  vllm/engine/ray_utils.py          (+56, -63)
  vllm/worker/worker.py             (+48, -21)
requirements.txt

 ninja  # For faster builds.
 psutil
-ray
+ray >= 2.5.1
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
vllm/engine/async_llm_engine.py

@@ -226,14 +226,14 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine
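
For context on the signature change above: initialize_cluster now returns a Ray placement group (or None when Ray workers are not used) in place of the old per-stage device IDs. A minimal sketch of the new call shape, assuming this commit is applied and that ParallelConfig still takes pipeline_parallel_size, tensor_parallel_size, and worker_use_ray:

# Illustrative only; not part of the diff.
from vllm.config import ParallelConfig
from vllm.engine.ray_utils import initialize_cluster

# Single GPU, no Ray workers: the placement-group slot comes back as None.
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1,
                                 worker_use_ray=False)
distributed_init_method, placement_group = initialize_cluster(parallel_config)
print(distributed_init_method)  # e.g. "tcp://localhost:<open port>"
print(placement_group)          # None, since worker_use_ray is False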
vllm/engine/llm_engine.py

 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, TYPE_CHECKING
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -13,7 +14,13 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
-from vllm.worker.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 logger = init_logger(__name__)
@@ -54,7 +61,7 @@ class LLMEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -85,31 +92,73 @@ class LLMEngine:
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now."
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)
 
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
         # Create the scheduler.
         self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
 
+    def _init_workers(self, distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+
+        self.workers: List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=placement_group,
+                    placement_group_capture_child_tasks=True),
+            )(RayWorker).remote()
+            self.workers.append(worker)
+
+        # Initialize torch distributed process group for the workers.
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                              self.model_config,
+                              self.parallel_config,
+                              self.scheduler_config,
+                              None,
+                              None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -152,11 +201,12 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(
+            parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -326,9 +376,10 @@ class LLMEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
 
             output = executor(*args, **kwargs)
             all_outputs.append(output)
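
Two standard Ray patterns do the heavy lifting in the new _init_workers_ray and _run_workers: actors are pinned to an existing placement group through PlacementGroupSchedulingStrategy, and method calls are routed through a single execute_method entry point bound with functools.partial. A self-contained sketch of both patterns, using CPU bundles so it runs without GPUs; ExampleWorker and its echo method are made-up stand-ins, not vLLM code:

# Standalone Ray sketch; illustrative only.
from functools import partial

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


class ExampleWorker:
    """Made-up actor mirroring RayWorker's generic method dispatch."""

    def execute_method(self, method, *args, **kwargs):
        return getattr(self, method)(*args, **kwargs)

    def echo(self, x):
        return x


ray.init(num_cpus=2)

# Reserve one 1-CPU bundle per worker (the diff reserves {"GPU": 1} bundles).
pg = placement_group([{"CPU": 1}] * 2)
ray.get(pg.ready())

workers = [
    ray.remote(
        num_cpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    )(ExampleWorker).remote()
    for _ in range(2)
]

# Bind the method name once, then call like a plain function, mirroring
# partial(worker.execute_method.remote, method) in _run_workers.
executors = [partial(w.execute_method.remote, "echo") for w in workers]
print(ray.get([executor(i) for i, executor in enumerate(executors)]))  # [0, 1]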
vllm/engine/ray_utils.py

 import socket
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, TYPE_CHECKING
+
+from vllm.config import ParallelConfig
 
 try:
     import ray
+    from ray.air.util.torch_dist import TorchDistributedWorker
+
+    class RayWorker(TorchDistributedWorker):
+        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
+        lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+
+        def __init__(self) -> None:
+            self.worker = None
+
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+
 except ImportError:
     ray = None
+    TorchDistributedWorker = None
 
-from vllm.config import ParallelConfig
-
-# rank, node resource (node IP), device id
-DeviceID = Tuple[int, Optional[str], int]
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 
 def get_open_port():
@@ -22,7 +42,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster probably with Ray.
 
     Args:
@@ -52,63 +72,36 @@ def initialize_cluster(
         # We need to setup the distributed init method to make sure
         # the distributed megatron code (e.g., get world size) works correctly.
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
-
-    # Assume we have a uniform cluster that each node has the same number of
-    # GPUs for now.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node["Resources"]["GPU"]
-        else:
-            assert num_devices_per_node == node["Resources"]["GPU"], (
-                "The number of GPUs per node is not uniform.")
-        for key in node["Resources"]:
-            if key.startswith("node:"):
-                valid_node_resources.append(key)
-
-    # Verify the parallel config.
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-            raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
-    else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
-            raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
-
-    # Assign GPUs to pipeline stages.
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = get_open_port()
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
-
-    return distributed_init_method, all_stage_devices
+        return distributed_init_method, None
+
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        gpu_bundles = 0
+        for bundle in bundles:
+            assert bundle.get("GPU", 0) > 1, (
+                "Placement group bundles cannot have more than 1 GPU")
+            if bundle.get("GPU", 0):
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
+    else:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        # Create a new placement group
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PG is ready - this will block until all
+        # requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)
+
+    return None, current_placement_group
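
The rewritten initialize_cluster either adopts the placement group the caller is already running in or creates one single-GPU bundle per worker and blocks until the group is provisioned. A standalone sketch of that detect-or-create flow, using CPU bundles so it runs on any machine (the diff uses {"GPU": 1} bundles and a 1800-second timeout):

# Standalone Ray sketch; illustrative only.
import ray
from ray.util import get_current_placement_group, placement_group

ray.init(num_cpus=2)
world_size = 2

current_pg = get_current_placement_group()  # None outside any placement group
if current_pg is None:
    # One bundle per worker, as the diff does with {"GPU": 1} bundles.
    current_pg = placement_group([{"CPU": 1}] * world_size)
    # ready() blocks until all bundles are reserved; the diff waits via
    # ray.get(current_placement_group.ready(), timeout=1800).
    ray.get(current_pg.ready(), timeout=60)

print(current_pg.bundle_specs)  # e.g. [{'CPU': 1.0}, {'CPU': 1.0}]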
vllm/worker/worker.py

 """A GPU worker class."""
-from typing import Dict, List, Tuple
+import os
+from typing import Dict, List, Tuple, Optional
 
 import torch
 import torch.distributed
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -27,8 +29,8 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: int,
-        distributed_init_method: str,
+        rank: Optional[int] = None,
+        distributed_init_method: Optional[str] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -36,27 +38,39 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
+        # Uninitialized cache engine. Will be initialized by
+        # self.init_cache_engine().
+        self.cache_config = None
+        self.block_size = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    def init_model(self):
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        # Env vars will be set by Ray.
+        self.rank = self.rank if self.rank is not None else int(
+            os.getenv("RANK", "-1"))
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.device = torch.device(f"cuda:{local_rank}")
+        if self.rank < 0:
+            raise ValueError("Invalid or unspecified rank.")
+        torch.cuda.set_device(self.device)
+
         # Initialize the distributed environment.
-        _init_distributed_environment(parallel_config, rank,
-                                      distributed_init_method)
+        _init_distributed_environment(self.parallel_config, self.rank,
+                                      self.distributed_init_method)
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
-        self.model = get_model(model_config)
+        self.model = get_model(self.model_config)
         initialize_all_reduce_launcher(
             self.scheduler_config.max_num_batched_tokens,
             self.model_config.get_hidden_size(),
             self.model_config.dtype,
         )
-        # Uninitialized cache engine. Will be initialized by
-        # self.init_cache_engine().
-        self.cache_config = None
-        self.block_size = None
-        self.cache_engine = None
-        self.cache_events = None
-        self.gpu_cache = None
 
     @torch.inference_mode()
     def profile_num_available_blocks(
         self,
@@ -294,15 +308,28 @@ class Worker:
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    distributed_init_method: str,
+    distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
-    )
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
 
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,
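
Worker construction may now omit rank and distributed_init_method; init_model fills them in later from the RANK and LOCAL_RANK environment variables that Ray's torch-dist setup exports. A minimal sketch of that deferred-initialization pattern, with LazyWorker as a hypothetical stand-in for the real Worker:

# Illustrative only; LazyWorker is not vLLM code.
import os
from typing import Optional


class LazyWorker:
    def __init__(self, rank: Optional[int] = None) -> None:
        self.rank = rank  # may be resolved later, as in the diff

    def init_model(self) -> None:
        # Env vars are expected to be set by the launcher (Ray in the diff).
        self.rank = self.rank if self.rank is not None else int(
            os.getenv("RANK", "-1"))
        if self.rank < 0:
            raise ValueError("Invalid or unspecified rank.")
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        # The real Worker would now select cuda:{local_rank} and join the
        # torch.distributed process group.


os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
worker = LazyWorker()        # constructed without a rank, like RayWorker's lazy Worker
worker.init_model()
print(worker.rank, worker.local_rank)  # 0 0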