Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5df2da5b
Unverified
Commit
5df2da5b
authored
Mar 20, 2025
by
Cody Yu
Committed by
GitHub
Mar 20, 2025
Browse files
[Misc] Better RayExecutor and multiprocessing compatibility (#14705)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
11b986b3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
21 deletions
+67
-21
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-1
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+2
-2
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+13
-8
vllm/utils.py
vllm/utils.py
+38
-10
No files found.
vllm/engine/arg_utils.py
View file @
5df2da5b
...
@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins
...
@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins
from
vllm.test_utils
import
MODEL_WEIGHTS_S3_BUCKET
,
MODELS_ON_S3
from
vllm.test_utils
import
MODEL_WEIGHTS_S3_BUCKET
,
MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
StoreBoolean
from
vllm.utils
import
FlexibleArgumentParser
,
StoreBoolean
,
is_in_ray_actor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
...
@@ -1245,6 +1245,18 @@ class EngineArgs:
...
@@ -1245,6 +1245,18 @@ class EngineArgs:
cpu_offload_gb
=
self
.
cpu_offload_gb
,
cpu_offload_gb
=
self
.
cpu_offload_gb
,
calculate_kv_scales
=
self
.
calculate_kv_scales
,
calculate_kv_scales
=
self
.
calculate_kv_scales
,
)
)
# Get the current placement group if Ray is initialized and
# we are in a Ray actor. If so, then the placement group will be
# passed to spawned processes.
placement_group
=
None
if
is_in_ray_actor
():
import
ray
# This call initializes Ray automatically if it is not initialized,
# but we should not do this here.
placement_group
=
ray
.
util
.
get_current_placement_group
()
parallel_config
=
ParallelConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
pipeline_parallel_size
=
self
.
pipeline_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
...
@@ -1257,6 +1269,7 @@ class EngineArgs:
...
@@ -1257,6 +1269,7 @@ class EngineArgs:
self
.
tokenizer_pool_extra_config
,
self
.
tokenizer_pool_extra_config
,
),
),
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
placement_group
=
placement_group
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
worker_cls
=
self
.
worker_cls
,
worker_cls
=
self
.
worker_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
...
...
vllm/executor/multiproc_worker_utils.py
View file @
5df2da5b
...
@@ -16,7 +16,7 @@ import torch
...
@@ -16,7 +16,7 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
_
check_multiproc_method
,
get_mp_context
,
run_method
from
vllm.utils
import
_
maybe_force_spawn
,
get_mp_context
,
run_method
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
...
@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
process before worker processes are created"""
_
check_multiproc_method
()
_
maybe_force_spawn
()
# Configure thread parallelism if OMP_NUM_THREADS isn't set
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
#
...
...
vllm/executor/ray_utils.py
View file @
5df2da5b
...
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
...
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available
()
assert_ray_available
()
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# Connect to a ray cluster.
if
ray
.
is_initialized
():
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
logger
.
info
(
"Ray is already initialized. Skipping Ray initialization."
)
elif
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# Try to connect existing ray instance and create a new one if not found
# Try to connect existing ray instance and create a new one if not found
try
:
try
:
ray
.
init
(
"auto"
,
ignore_reinit_error
=
True
)
ray
.
init
(
"auto"
,
ignore_reinit_error
=
True
)
...
@@ -299,19 +300,21 @@ def initialize_ray_cluster(
...
@@ -299,19 +300,21 @@ def initialize_ray_cluster(
else
:
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
device_str
=
current_platform
.
ray_device_key
device_str
=
current_platform
.
ray_device_key
if
not
device_str
:
if
not
device_str
:
raise
ValueError
(
raise
ValueError
(
f
"current platform
{
current_platform
.
device_name
}
does not "
f
"current platform
{
current_platform
.
device_name
}
does not "
"support ray."
)
"support ray."
)
# Create placement group for worker processes
# Create or get the placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
parallel_config
.
placement_group
:
current_placement_group
=
parallel_config
.
placement_group
else
:
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
if
current_placement_group
:
logger
.
info
(
"Using the existing placement group"
)
# We are in a placement group
# We are in a placement group
bundles
=
current_placement_group
.
bundle_specs
bundles
=
current_placement_group
.
bundle_specs
# Verify that we can use the placement group.
# Verify that we can use the placement group.
...
@@ -331,6 +334,8 @@ def initialize_ray_cluster(
...
@@ -331,6 +334,8 @@ def initialize_ray_cluster(
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Total number of devices:
{
device_bundles
}
."
)
f
"Total number of devices:
{
device_bundles
}
."
)
else
:
else
:
logger
.
info
(
"No current placement group found. "
"Creating a new placement group."
)
num_devices_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
num_devices_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
# Log a warning message and delay resource allocation failure response.
# Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group
# Avoid immediate rejection to allow user-initiated placement group
...
...
vllm/utils.py
View file @
5df2da5b
...
@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
...
@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
0
)
def
_check_multiproc_method
():
def
is_in_ray_actor
():
if
(
cuda_is_initialized
()
"""Check if we are in a Ray actor."""
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
):
logger
.
warning
(
"CUDA was previously initialized. We must use "
try
:
"the `spawn` multiprocessing start method. Setting "
import
ray
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
return
(
ray
.
is_initialized
()
"See https://docs.vllm.ai/en/latest/getting_started/"
and
ray
.
get_runtime_context
().
get_actor_id
()
is
not
None
)
"troubleshooting.html#python-multiprocessing "
except
ImportError
:
"for more information."
)
return
False
def
_maybe_force_spawn
():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
==
"spawn"
:
return
reason
=
None
if
cuda_is_initialized
():
reason
=
"CUDA is initialized"
elif
is_in_ray_actor
():
reason
=
"In a Ray actor and can only be spawned"
if
reason
is
not
None
:
logger
.
warning
(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reason: %s"
,
reason
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
def
get_mp_context
():
def
get_mp_context
():
_check_multiproc_method
()
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
return
multiprocessing
.
get_context
(
mp_method
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment