Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d0e3d37
Unverified
Commit
6d0e3d37
authored
Jan 18, 2025
by
youkaichao
Committed by
GitHub
Jan 18, 2025
Browse files
[core] clean up executor class hierarchy between v1 and v0 (#12171)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
02798eca
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
798 deletions
+61
-798
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+0
-10
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+58
-29
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+3
-47
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+0
-344
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+0
-280
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+0
-88
No files found.
vllm/executor/executor_base.py
View file @
6d0e3d37
...
@@ -79,16 +79,6 @@ class ExecutorBase(ABC):
...
@@ -79,16 +79,6 @@ class ExecutorBase(ABC):
b
=
min
([
r
[
1
]
for
r
in
results
])
b
=
min
([
r
[
1
]
for
r
in
results
])
return
a
,
b
return
a
,
b
def
initialize
(
self
,
num_gpu_blocks
:
int
)
->
None
:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
For V1 compatibility.
"""
logger
.
info
(
"# GPU blocks: %d"
,
num_gpu_blocks
)
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
)
->
None
:
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
)
->
None
:
"""Initialize the KV cache by invoking the underlying worker.
"""Initialize the KV cache by invoking the underlying worker.
"""
"""
...
...
vllm/v1/executor/abstract.py
View file @
6d0e3d37
from
abc
import
ABC
,
abstractmethod
from
typing
import
Type
from
typing
import
Type
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
as
RayDistributedExecutorV0
)
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
as
ExecutorWithExternalLauncherV0
)
from
vllm.executor.uniproc_executor
import
(
# noqa
UniProcExecutor
as
UniProcExecutorV0
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
class
Executor
(
ABC
):
class
Executor
(
ExecutorBase
):
"""Abstract class for executors."""
"""
Abstract class for v1 executors, mainly define some methods for v1.
For methods shared by v0 and v1, define them in ExecutorBase"""
@
staticmethod
@
staticmethod
def
get_class
(
vllm_config
:
VllmConfig
)
->
Type
[
"Executor"
]:
def
get_class
(
vllm_config
:
VllmConfig
)
->
Type
[
"Executor"
]:
executor_class
:
Type
[
Executor
]
executor_class
:
Type
[
Executor
]
parallel_config
=
vllm_config
.
parallel_config
distributed_executor_backend
=
(
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
is
None
:
# If the user does not specify the distributed executor backend,
# we will choose the backend based on the world size.
if
parallel_config
.
world_size
>
1
:
distributed_executor_backend
=
"mp"
else
:
distributed_executor_backend
=
"uni"
if
distributed_executor_backend
==
"ray"
:
if
distributed_executor_backend
==
"ray"
:
from
vllm.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
)
executor_class
=
RayDistributedExecutor
executor_class
=
RayDistributedExecutor
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
executor_class
=
MultiprocExecutor
executor_class
=
MultiprocExecutor
elif
distributed_executor_backend
==
"uni"
:
executor_class
=
UniProcExecutor
elif
distributed_executor_backend
==
"external_launcher"
:
# TODO: make v1 scheduling deterministic
# to support external launcher
executor_class
=
ExecutorWithExternalLauncher
else
:
else
:
assert
(
distributed_executor_backend
is
None
)
raise
ValueError
(
"Unknown distributed executor backend: "
from
vllm.v1.executor.uniproc_executor
import
UniprocExecutor
f
"
{
distributed_executor_backend
}
"
)
executor_class
=
UniprocExecutor
return
executor_class
return
executor_class
@
abstractmethod
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
raise
NotImplementedError
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
kv_cache_config
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
@
abstractmethod
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
int
:
# in bytes
raise
NotImplementedError
output
=
self
.
collective_rpc
(
"determine_available_memory"
)
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return
min
(
output
)
@
abstractmethod
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
raise
NotImplementedError
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
for
x
in
output
:
assert
x
==
output
[
0
]
return
output
[
0
]
@
abstractmethod
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
,
scheduler_output
,
)
->
ModelRunnerOutput
:
)
->
ModelRunnerOutput
:
raise
NotImplementedError
output
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
))
return
output
[
0
]
@
abstractmethod
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
raise
NotImplementedError
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
class
UniProcExecutor
(
UniProcExecutorV0
,
Executor
):
pass
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
pass
@
abstractmethod
def
shutdown
(
self
):
pass
@
abstractmethod
class
RayDistributedExecutor
(
RayDistributedExecutorV0
,
Executor
):
def
check_health
(
self
)
->
None
:
pass
raise
NotImplementedError
vllm/v1/executor/multiproc_executor.py
View file @
6d0e3d37
...
@@ -25,8 +25,6 @@ from vllm.logger import init_logger
...
@@ -25,8 +25,6 @@ from vllm.logger import init_logger
from
vllm.utils
import
(
get_distributed_init_method
,
get_mp_context
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_mp_context
,
get_open_port
,
get_open_zmq_ipc_path
,
zmq_socket_ctx
)
get_open_port
,
get_open_zmq_ipc_path
,
zmq_socket_ctx
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -37,7 +35,7 @@ POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
...
@@ -37,7 +35,7 @@ POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class
MultiprocExecutor
(
Executor
):
class
MultiprocExecutor
(
Executor
):
def
_
_init_
_
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
def
_init_
executor
(
self
)
->
None
:
# Call self.shutdown at exit to clean up
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
# and ensure workers will be terminated.
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
shutdown
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
shutdown
)
...
@@ -55,9 +53,6 @@ class MultiprocExecutor(Executor):
...
@@ -55,9 +53,6 @@ class MultiprocExecutor(Executor):
signal
.
signal
(
signal
.
SIGUSR1
,
sigusr1_handler
)
signal
.
signal
(
signal
.
SIGUSR1
,
sigusr1_handler
)
self
.
vllm_config
=
vllm_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
world_size
=
self
.
parallel_config
.
world_size
self
.
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
assert
self
.
world_size
==
tensor_parallel_size
,
(
assert
self
.
world_size
==
tensor_parallel_size
,
(
...
@@ -82,7 +77,8 @@ class MultiprocExecutor(Executor):
...
@@ -82,7 +77,8 @@ class MultiprocExecutor(Executor):
# Create workers
# Create workers
self
.
workers
:
List
[
WorkerProcHandle
]
=
[]
self
.
workers
:
List
[
WorkerProcHandle
]
=
[]
for
rank
in
range
(
self
.
world_size
):
for
rank
in
range
(
self
.
world_size
):
worker
=
WorkerProc
.
make_worker_process
(
vllm_config
,
rank
,
rank
,
worker
=
WorkerProc
.
make_worker_process
(
self
.
vllm_config
,
rank
,
rank
,
distributed_init_method
,
distributed_init_method
,
scheduler_output_handle
)
scheduler_output_handle
)
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
...
@@ -93,34 +89,6 @@ class MultiprocExecutor(Executor):
...
@@ -93,34 +89,6 @@ class MultiprocExecutor(Executor):
for
w
in
self
.
workers
:
for
w
in
self
.
workers
:
w
.
worker_response_mq
.
wait_until_ready
()
w
.
worker_response_mq
.
wait_until_ready
()
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
kv_cache_config
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
determine_available_memory
(
self
)
->
int
:
"""
Determine the available memory (in bytes) for KV cache by invoking the
underlying worker.
"""
memory_sizes
=
self
.
collective_rpc
(
"determine_available_memory"
)
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return
min
(
memory_sizes
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
"""
Get all kv cache needed by the model by invoking the underlying worker.
"""
kv_cache_specs
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
assert
all
(
s
==
kv_cache_specs
[
0
]
for
s
in
kv_cache_specs
)
return
kv_cache_specs
[
0
]
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
...
@@ -172,18 +140,6 @@ class MultiprocExecutor(Executor):
...
@@ -172,18 +140,6 @@ class MultiprocExecutor(Executor):
# Re-raise any other exceptions
# Re-raise any other exceptions
raise
e
raise
e
def
execute_model
(
self
,
scheduler_output
,
)
->
ModelRunnerOutput
:
model_output
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
))[
0
]
return
model_output
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
return
def
_ensure_worker_termination
(
self
):
def
_ensure_worker_termination
(
self
):
"""Ensure that all worker processes are terminated. Assumes workers have
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
received termination requests. Waits for processing, then sends
...
...
vllm/v1/executor/ray_executor.py
deleted
100644 → 0
View file @
02798eca
import
os
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.ray_utils
import
(
RayWorkerWrapper
,
initialize_ray_cluster
,
ray
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
class
RayExecutor
(
Executor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
model_config
=
vllm_config
.
model_config
self
.
forward_dag
:
Optional
[
ray
.
dag
.
CompiledDAG
]
=
None
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
initialize_ray_cluster
(
self
.
parallel_config
)
placement_group
=
self
.
parallel_config
.
placement_group
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
# A list of workers to run a model.
self
.
workers
:
List
[
RayWorkerWrapper
]
=
[]
if
self
.
parallel_config
.
ray_workers_use_nsight
:
ray_remote_kwargs
=
self
.
_configure_ray_workers_use_nsight
(
ray_remote_kwargs
)
# Create the workers.
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
# Skip bundles that don't have GPUs,
# as each worker needs one GPU.
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
1
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
)
self
.
workers
.
append
(
worker
)
logger
.
debug
(
"workers: %s"
,
self
.
workers
)
worker_ips
=
[
ray
.
get
(
worker
.
get_node_ip
.
remote
())
# type: ignore[attr-defined]
for
worker
in
self
.
workers
]
ip_counts
:
Dict
[
str
,
int
]
=
{}
for
ip
in
worker_ips
:
ip_counts
[
ip
]
=
ip_counts
.
get
(
ip
,
0
)
+
1
worker_to_ip
=
dict
(
zip
(
self
.
workers
,
worker_ips
))
def
sort_by_driver_then_worker_ip
(
worker
):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first. This is simply a tiebreaker to make
sure the workers are sorted in a deterministic way.
"""
ip
=
worker_to_ip
[
worker
]
return
(
ip
!=
driver_ip
,
ip_counts
[
ip
],
ip
)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self
.
workers
=
sorted
(
self
.
workers
,
key
=
sort_by_driver_then_worker_ip
)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
)
node_workers
=
defaultdict
(
list
)
# node id -> list of worker ranks
node_gpus
=
defaultdict
(
list
)
# node id -> list of gpu ids
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
):
node_workers
[
node_id
].
append
(
i
)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids
=
[
int
(
x
)
for
x
in
gpu_ids
]
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
all_ips
=
set
(
worker_ips
)
n_ips
=
len
(
all_ips
)
n_nodes
=
len
(
node_workers
)
if
n_nodes
!=
n_ips
:
raise
RuntimeError
(
f
"Every node should have a unique IP address. Got
{
n_nodes
}
"
f
" nodes with node ids
{
list
(
node_workers
.
keys
())
}
and "
f
"
{
n_ips
}
unique IP addresses
{
all_ips
}
. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node."
)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
"CUDA_VISIBLE_DEVICES"
:
","
.
join
(
map
(
str
,
node_gpus
[
node_id
])),
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
"VLLM_USE_V1"
:
str
(
int
(
envs
.
VLLM_USE_V1
)),
**
({
"VLLM_ATTENTION_BACKEND"
:
envs
.
VLLM_ATTENTION_BACKEND
}
if
envs
.
VLLM_ATTENTION_BACKEND
is
not
None
else
{})
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
self
.
_env_vars_for_all_workers
=
(
all_args_to_update_environment_variables
)
self
.
_run_workers
(
"update_environment_variables"
,
all_args
=
self
.
_get_env_vars_to_be_updated
())
if
len
(
node_gpus
)
==
1
:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip
=
"127.0.0.1"
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs
=
[
self
.
_get_worker_kwargs
(
local_rank
=
node_workers
[
node_id
].
index
(
rank
),
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
for
rank
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
)
]
self
.
_run_workers
(
"init_worker"
,
all_kwargs
=
init_worker_all_kwargs
)
self
.
_run_workers
(
"initialize"
)
self
.
_run_workers
(
"load_model"
)
def
_configure_ray_workers_use_nsight
(
self
,
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env
=
ray_remote_kwargs
.
setdefault
(
"runtime_env"
,
{})
runtime_env
.
update
({
"nsight"
:
{
"t"
:
"cuda,cudnn,cublas"
,
"o"
:
"'worker_process_%p'"
,
"cuda-graph-trace"
:
"node"
,
}
})
return
ray_remote_kwargs
def
_get_env_vars_to_be_updated
(
self
):
return
self
.
_env_vars_for_all_workers
def
_get_worker_kwargs
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Any
]:
"""
Return worker init args for a given rank.
"""
if
distributed_init_method
is
None
:
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
return
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
def
determine_available_memory
(
self
)
->
int
:
"""
Determine the available GPU memory in bytes.
This invokes `determine_available_memory` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
"""
memory_sizes
=
self
.
_run_workers
(
"determine_available_memory"
)
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return
min
(
memory_sizes
)
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize the KV cache in all workers.
"""
self
.
_run_workers
(
"initialize_cache"
,
kv_cache_config
)
self
.
_run_workers
(
"compile_or_warm_up_model"
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
"""
Get all kv cache needed by the model
This invokes `get_kv_cache_spec` on each worker and asserts that
they are identical. The KVCacheSpec is then returned.
"""
kv_cache_specs
=
self
.
_run_workers
(
"get_kv_cache_spec"
)
assert
all
(
s
==
kv_cache_specs
[
0
]
for
s
in
kv_cache_specs
)
return
kv_cache_specs
[
0
]
def
_run_workers
(
self
,
method
:
str
,
*
args
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
**
kwargs
,
)
->
Any
:
"""
Runs the given method on all workers. Can be used in the following
ways:
Args:
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
count
=
len
(
self
.
workers
)
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
0
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
0
,
None
)
ray_worker_refs
=
[
worker
.
execute_method
.
remote
(
# type: ignore[attr-defined]
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
return
ray
.
get
(
ray_worker_refs
)
def
execute_model
(
self
,
scheduler_output
,
)
->
ModelRunnerOutput
:
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
# Only the first worker (with rank 0) returns the execution result.
# Others return None.
output
=
ray
.
get
(
self
.
forward_dag
.
execute
(
scheduler_output
))[
0
]
return
output
def
profile
(
self
,
is_start
=
True
):
raise
NotImplementedError
def
shutdown
(
self
):
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
self
.
forward_dag
=
None
def
check_health
(
self
)
->
None
:
logger
.
debug
(
"Called check_health."
)
def
_check_ray_compiled_graph_installation
(
self
):
import
pkg_resources
from
packaging
import
version
required_version
=
version
.
parse
(
"2.39"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
import
importlib.util
raycg
=
importlib
.
util
.
find_spec
(
"ray.experimental.compiled_dag_ref"
)
if
raycg
is
None
:
raise
ValueError
(
"Ray Compiled Graph is not installed. "
"Run `pip install ray[adag]` to install it."
)
cupy_spec
=
importlib
.
util
.
find_spec
(
"cupy"
)
if
cupy_spec
is
None
and
envs
.
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
:
raise
ValueError
(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
"Run `pip install ray[adag]` and check cupy installation."
)
def
_compiled_ray_dag
(
self
):
assert
self
.
parallel_config
.
use_ray
self
.
_check_ray_compiled_graph_installation
()
from
ray.dag
import
InputNode
,
MultiOutputNode
with
InputNode
()
as
input_batches
:
outputs
=
[
worker
.
execute_model
.
bind
(
# type: ignore[attr-defined]
input_batches
)
for
worker
in
self
.
workers
]
forward_dag
=
MultiOutputNode
(
outputs
)
return
forward_dag
.
experimental_compile
()
def
__del__
(
self
):
self
.
shutdown
()
vllm/v1/executor/ray_utils.py
deleted
100644 → 0
View file @
02798eca
import
time
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_ip
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.worker.worker_base
import
WorkerWrapperBase
if
TYPE_CHECKING
:
from
vllm.v1.core.scheduler
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
PG_WAIT_TIMEOUT
=
60
try
:
import
ray
from
ray.util
import
placement_group_table
from
ray.util.placement_group
import
PlacementGroup
try
:
from
ray._private.state
import
available_resources_per_node
except
ImportError
:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from
ray._private.state
import
state
as
_state
available_resources_per_node
=
_state
.
_available_resources_per_node
class
RayWorkerWrapper
(
WorkerWrapperBase
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread. It will be removed soon.
self
.
compiled_dag_cuda_device_set
=
False
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
def
get_node_and_gpu_ids
(
self
)
->
Tuple
[
str
,
List
[
int
]]:
node_id
=
ray
.
get_runtime_context
().
get_node_id
()
device_key
=
current_platform
.
ray_device_key
if
not
device_key
:
raise
RuntimeError
(
"current platform %s does not support ray."
,
current_platform
.
device_name
)
gpu_ids
=
ray
.
get_runtime_context
().
get_accelerator_ids
(
)[
device_key
]
return
node_id
,
gpu_ids
def
setup_device_if_necessary
(
self
):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
# device.
# We can remove this API after it is fixed in compiled graph.
import
torch
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
ModelRunnerOutput
:
self
.
setup_device_if_necessary
()
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
)
return
output
ray_import_err
=
None
except
ImportError
as
e
:
ray
=
None
# type: ignore
ray_import_err
=
e
RayWorkerWrapper
=
None
# type: ignore
def
ray_is_available
()
->
bool
:
"""Returns True if Ray is available."""
return
ray
is
not
None
def
assert_ray_available
():
"""
Raise an exception if Ray is not available.
"""
if
ray
is
None
:
raise
ValueError
(
"Failed to import Ray, please install Ray with "
"`pip install ray`."
)
from
ray_import_err
def
_verify_bundles
(
placement_group
:
"PlacementGroup"
,
parallel_config
:
ParallelConfig
,
device_str
:
str
):
"""
Verify a given placement group has bundles located in the right place.
There are 2 rules.
- Warn if all tensor parallel workers cannot fit in a single node.
- Fail if driver node is not included in a placement group.
Args:
placement_group: The placement group to verify.
parallel_config: The parallel configuration.
device_str: The required device.
"""
assert
ray
.
is_initialized
(),
(
"Ray is not initialized although distributed-executor-backend is ray."
)
pg_data
=
placement_group_table
(
placement_group
)
# bundle_idx -> node_id
bundle_to_node_ids
=
pg_data
[
"bundles_to_node_id"
]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles
=
pg_data
[
"bundles"
]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle
:
Dict
[
str
,
List
[
Dict
[
str
,
float
]]]
=
defaultdict
(
list
)
for
bundle_idx
,
node_id
in
bundle_to_node_ids
.
items
():
node_id_to_bundle
[
node_id
].
append
(
bundles
[
bundle_idx
])
driver_node_id
=
ray
.
get_runtime_context
().
get_node_id
()
if
driver_node_id
not
in
node_id_to_bundle
:
raise
RuntimeError
(
f
"driver node id
{
driver_node_id
}
is not included in a placement "
f
"group
{
placement_group
.
id
}
. Node id -> bundles "
f
"
{
node_id_to_bundle
}
. "
"You don't have enough GPUs available in a current node. Check "
"`ray status` to see if you have available GPUs in a node "
f
"
{
driver_node_id
}
before starting an vLLM engine."
)
for
node_id
,
bundles
in
node_id_to_bundle
.
items
():
if
len
(
bundles
)
<
parallel_config
.
tensor_parallel_size
:
logger
.
warning
(
"tensor_parallel_size=%d "
"is bigger than a reserved number of %ss (%d "
"%ss) in a node %s. Tensor parallel workers can be "
"spread out to 2+ nodes which can degrade the performance "
"unless you have fast interconnect across nodes, like "
"Infiniband. To resolve this issue, make sure you have more "
"than %d GPUs available at each node."
,
parallel_config
.
tensor_parallel_size
,
device_str
,
len
(
bundles
),
device_str
,
node_id
,
parallel_config
.
tensor_parallel_size
)
def
_wait_until_pg_ready
(
current_placement_group
:
"PlacementGroup"
):
"""Wait until a placement group is ready.
It prints the informative log messages if the placement group is
not created within time.
"""
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs
=
current_placement_group
.
bundle_specs
s
=
time
.
time
()
pg_ready_ref
=
current_placement_group
.
ready
()
wait_interval
=
10
while
time
.
time
()
-
s
<
PG_WAIT_TIMEOUT
:
ready
,
_
=
ray
.
wait
([
pg_ready_ref
],
timeout
=
wait_interval
)
if
len
(
ready
)
>
0
:
break
# Exponential backoff for warning print.
wait_interval
*=
2
logger
.
info
(
"Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources."
,
int
(
time
.
time
()
-
s
),
placement_group_specs
)
try
:
ray
.
get
(
pg_ready_ref
,
timeout
=
0
)
except
ray
.
exceptions
.
GetTimeoutError
:
raise
ValueError
(
"Cannot provide a placement group of "
f
"
{
placement_group_specs
=
}
within
{
PG_WAIT_TIMEOUT
}
seconds. See "
"`ray status` to make sure the cluster has enough resources."
)
from
None
def
initialize_ray_cluster
(
parallel_config
:
ParallelConfig
,
ray_address
:
Optional
[
str
]
=
None
,
):
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
"""
assert_ray_available
()
# Connect to a ray cluster.
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# Try to connect existing ray instance and create a new one if not found
try
:
ray
.
init
(
"auto"
)
except
ConnectionError
:
logger
.
warning
(
"No existing RAY instance detected. "
"A new instance will be launched with current node resources."
)
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
device_str
=
current_platform
.
ray_device_key
if
not
device_str
:
raise
ValueError
(
f
"current platform
{
current_platform
.
device_name
}
does not "
"support ray."
)
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
# We are in a placement group
bundles
=
current_placement_group
.
bundle_specs
# Verify that we can use the placement group.
device_bundles
=
0
for
bundle
in
bundles
:
bundle_devices
=
bundle
.
get
(
device_str
,
0
)
if
bundle_devices
>
1
:
raise
ValueError
(
"Placement group bundle cannot have more than 1 "
f
"
{
device_str
}
."
)
if
bundle_devices
:
device_bundles
+=
1
if
parallel_config
.
world_size
>
device_bundles
:
raise
ValueError
(
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Total number of devices:
{
device_bundles
}
."
)
else
:
num_devices_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
if
parallel_config
.
world_size
>
num_devices_in_cluster
:
raise
ValueError
(
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
)
# Create a new placement group
placement_group_specs
:
List
[
Dict
[
str
,
float
]]
=
([{
device_str
:
1.0
}
for
_
in
range
(
parallel_config
.
world_size
)])
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip
=
get_ip
()
current_node_id
=
ray
.
get_runtime_context
().
get_node_id
()
current_node_resource
=
available_resources_per_node
()[
current_node_id
]
if
current_node_resource
.
get
(
device_str
,
0
)
<
1
:
raise
ValueError
(
f
"Current node has no
{
device_str
}
available. "
f
"
{
current_node_resource
=
}
. vLLM engine cannot start without "
f
"
{
device_str
}
. Make sure you have at least 1
{
device_str
}
"
f
"available in a node
{
current_node_id
=
}
{
current_ip
=
}
."
)
# This way, at least bundle is required to be created in a current
# node.
placement_group_specs
[
0
][
f
"node:
{
current_ip
}
"
]
=
0.001
# By default, Ray packs resources as much as possible.
current_placement_group
=
ray
.
util
.
placement_group
(
placement_group_specs
,
strategy
=
"PACK"
)
_wait_until_pg_ready
(
current_placement_group
)
assert
current_placement_group
is
not
None
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
vllm/v1/executor/uniproc_executor.py
deleted
100644 → 0
View file @
02798eca
import
os
from
typing
import
Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu_worker
import
Worker
logger
=
init_logger
(
__name__
)
class
UniprocExecutor
(
Executor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device_config
=
vllm_config
.
device_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
worker
:
Worker
=
self
.
_create_worker
()
self
.
worker
.
init_device
()
self
.
worker
.
load_model
()
def
_create_worker
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
)
->
Worker
:
"""Return worker init args for a given rank."""
# see https://github.com/NVIDIA/nccl/issues/1234
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
if
distributed_init_method
is
None
:
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
return
Worker
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
def
determine_available_memory
(
self
)
->
int
:
"""Determine the available memory (in bytes) for KV cache by invoking
the underlying worker.
"""
return
self
.
worker
.
determine_available_memory
()
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
"""Get all kv cache needed by the model by invoking the underlying
worker.
"""
return
self
.
worker
.
get_kv_cache_spec
()
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Initialize the KV cache by invoking the underlying worker.
"""
self
.
worker
.
initialize_cache
(
kv_cache_config
)
self
.
worker
.
compile_or_warm_up_model
()
def
execute_model
(
self
,
scheduler_output
,
)
->
ModelRunnerOutput
:
output
=
self
.
worker
.
execute_model
(
scheduler_output
)
assert
output
is
not
None
return
output
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
worker
.
profile
(
is_start
)
def
shutdown
(
self
):
pass
def
check_health
(
self
)
->
None
:
# UniprocExecutor will always be healthy as long as
# it's running.
return
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment