Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d3317594
Unverified
Commit
d3317594
authored
Aug 01, 2025
by
Rui Qiao
Committed by
GitHub
Aug 01, 2025
Browse files
Introduce RayPPCommunicator for ray-based PP (#21660)
Signed-off-by:
Rui Qiao
<
ruisearch42@gmail.com
>
parent
9659bc7f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
280 additions
and
0 deletions
+280
-0
vllm/distributed/device_communicators/ray_communicator.py
vllm/distributed/device_communicators/ray_communicator.py
+257
-0
vllm/envs.py
vllm/envs.py
+8
-0
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+15
-0
No files found.
vllm/distributed/device_communicators/ray_communicator.py
0 → 100644
View file @
d3317594
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
uuid
from
typing
import
Any
,
Optional
import
ray
import
torch
from
ray.exceptions
import
RayChannelError
from
ray.experimental.channel.communicator
import
(
Communicator
,
TorchTensorAllocator
)
from
torch.distributed
import
ReduceOp
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.utils
import
current_stream
# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's `__name__`.
logger = init_logger(__name__)
class RayPPCommunicator(Communicator):
    """
    Communicator to be used for pipeline parallelism in Ray Compiled Graph.
    This wraps around the vLLM _PP GroupCoordinator.

    This class is not thread-safe.
    """

    # Device communicator of the vLLM _PP group; None on the Ray driver.
    _comm: Optional[DeviceCommunicatorBase]
    # Maps a Ray actor ID (hex string) to its rank in the _PP group. Only
    # populated on workers, by `_build_actor_rank_mapping`.
    _actor_id_to_rank: dict[str, int]

    def __init__(
        self,
        world_size: int,
        comm_id: Any,
        rank: Optional[int],
        actor_handles: list["ray.actor.ActorHandle"],
        cuda_stream: Optional[torch.cuda.Stream],
        use_communication_streams: bool = False,
    ):
        """
        Initialize a RayPPCommunicator that can be used to communicate with
        other Ray Compiled Graph actors for pipeline parallelism.

        Args:
            world_size: The number of participating actors.
            comm_id: A unique communicator ID. This is just to conform with
                the Ray Communicator API and is not used.
            rank: The rank of this actor. If None, then the caller is not a
                participant of the RayPPCommunicator group (e.g., the Ray
                driver). Only its None-ness is used; the actual rank is
                taken from the vLLM _PP communicator.
            actor_handles: A list of actor handles.
            cuda_stream: A CUDA stream to dispatch communication ops to.
                Anything other than None or the current stream is rejected.
            use_communication_streams: Whether to use communication streams.
                This is not supported.

        Raises:
            NotImplementedError: If `use_communication_streams` is true.
            ValueError: If `cuda_stream` is given and is not the current
                stream.
        """
        self._world_size = world_size
        self._rank: Optional[int] = None
        self._actor_handles = actor_handles
        if use_communication_streams:
            raise NotImplementedError(
                "use_communication_streams is not supported")
        if cuda_stream is not None and cuda_stream != current_stream():
            raise ValueError(
                "cuda_stream other than the current stream is not supported")

        if rank is not None:
            # Rank is not None, this is Ray worker
            assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

            self._comm = get_pp_group().device_communicator

            # Since we wrap around the vLLM _PP communicator, we use
            # the rank from the vLLM communicator, and ignore the rank
            # passed in from Ray.
            # TODO(rui): refactor the Ray Communicator API so that
            # it also supports no rank passed in.
            self._rank = self._comm.rank_in_group

            self._build_actor_rank_mapping()
        else:
            # Rank is None, this is Ray driver
            self._comm = None

        self._closed = False

    def _raise_if_closed(self) -> None:
        """Raise RayChannelError if `destroy()` has been called."""
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

    def _build_actor_rank_mapping(self) -> None:
        """
        Use collective communication to build a mapping from actor IDs to
        ranks. This should be called once during initialization.

        Every participant all-gathers its own actor ID (as hex characters)
        through the wrapped device communicator; the position of an ID in
        the gathered buffer is that actor's rank. Does nothing when there
        is no device communicator (i.e. on the driver).
        """
        if self._comm is None:
            return

        current_actor = ray.get_runtime_context().current_actor
        actor_id_str = current_actor._actor_id.hex()

        # Ray actor IDs are 32-character hex strings (128 bits)
        ACTOR_ID_LEN = 32
        actor_id_bytes = actor_id_str.encode('utf-8')
        assert len(
            actor_id_bytes
        ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"

        # `torch.frombuffer` warns on read-only buffers such as `bytes`, so
        # hand it a writable copy.
        actor_id_tensor = torch.frombuffer(
            bytearray(actor_id_bytes),
            dtype=torch.uint8).to(self._comm.device)

        # All-gather full actor IDs from all actors
        gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

        # Copy to host once instead of once per rank.
        gathered_bytes = gathered_ids.cpu().numpy().tobytes()

        # Derive the rank count from what was actually gathered rather than
        # from `self._world_size`: slicing past the end of the gathered
        # buffer would yield empty IDs and insert a bogus '' key.
        num_ranks = len(gathered_bytes) // ACTOR_ID_LEN

        # Build mapping: actor_id -> device_comm_rank
        self._actor_id_to_rank = {
            gathered_bytes[idx * ACTOR_ID_LEN:(idx + 1) *
                           ACTOR_ID_LEN].decode('utf-8'): idx
            for idx in range(num_ranks)
        }

    def initialize(self, rank: int) -> None:
        """No-op; no additional initialization is needed."""
        # No additional initialization is needed.
        pass

    def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
        """Return the actor handles passed to the constructor."""
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank, looked up in the actor-ID-to-rank
        mapping that was built during initialization.

        Raises:
            ValueError: If the actor is not part of the mapping.
        """
        assert hasattr(self, '_actor_id_to_rank'), (
            "Actor rank mapping not built. "
            "This should have been done during initialization.")

        actor_id_str = actor._actor_id.hex()

        try:
            return self._actor_id_to_rank[actor_id_str]
        except KeyError:
            raise ValueError(
                f"Actor {actor} not found in communicator group") from None

    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank (None on a non-participant such as the
        driver).
        """
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the RayPPCommunicator group.
        """
        return self._world_size

    def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `buf` until the send has finished.
        That is, either all writes should be submitted on the current stream
        or, if on a different stream, that stream should synchronize with the
        current stream.

        Args:
            buf: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.

        Raises:
            RayChannelError: If the communicator has been destroyed.
        """
        self._raise_if_closed()

        assert self._comm is not None
        self._comm.send(buf, peer_rank)

    def recv(
        self,
        shape: tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: TorchTensorAllocator,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize the current stream.

        After this call returns, the receive buffer is safe to read from
        any stream. A RayChannelError will be raised if an error occurred
        (e.g., remote actor died), and the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: The allocator to use to create the received tensor.
                This is ignored for this implementation.

        Returns:
            The received tensor.

        Raises:
            RayChannelError: If the communicator was destroyed before the
                call or by the time the stream synchronization finished.
        """
        self._raise_if_closed()

        assert self._comm is not None
        size = torch.Size(shape)
        buf = self._comm.recv(size, dtype, src=peer_rank)

        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
        # need to synchronize here and check that the channel is still
        # open to ensure that the receive buffer is valid.
        # TODO(swang): Avoid CUDA synchronization.
        current_stream().synchronize()

        self._raise_if_closed()
        return buf

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("allgather is not supported")

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("allreduce is not supported")

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("reducescatter is not supported")

    @property
    def recv_stream(self):
        """Return a StreamContext wrapping the current stream."""
        return torch.cuda.StreamContext(current_stream())

    @property
    def send_stream(self):
        """Return a StreamContext wrapping the current stream."""
        return torch.cuda.StreamContext(current_stream())

    def destroy(self) -> None:
        """Mark the communicator as closed; later send/recv calls raise."""
        # Just sets a flag, vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
        self._closed = True

    def get_transport_name(self) -> str:
        """Return the transport name reported to Ray ("nccl")."""
        return "nccl"

    @classmethod
    def generate_communicator_id(cls) -> Any:
        """Return a fresh random UUID to serve as a communicator ID."""
        return uuid.uuid4()
vllm/envs.py
View file @
d3317594
...
...
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
:
str
=
"auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM
:
bool
=
False
VLLM_USE_RAY_WRAPPED_PP_COMM
:
bool
=
True
VLLM_XLA_USE_SPMD
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"fork"
VLLM_ASSETS_CACHE
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"assets"
)
...
...
@@ -498,6 +499,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM"
,
"0"
))
),
# If set to 1 (the default), use a Ray Communicator wrapping
# vLLM's pipeline parallelism communicator to interact with Ray's
# Compiled Graph. If set to 0, use Ray's NCCL communicator instead.
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_WRAPPED_PP_COMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_RAY_WRAPPED_PP_COMM"
,
"1"
))),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD"
:
...
...
vllm/executor/ray_distributed_executor.py
View file @
d3317594
...
...
@@ -608,6 +608,21 @@ class RayDistributedExecutor(DistributedExecutorBase):
forward_dag
=
MultiOutputNode
(
outputs
)
if
envs
.
VLLM_USE_RAY_WRAPPED_PP_COMM
:
from
ray.experimental.channel.accelerator_context
import
(
register_accelerator_context
)
from
vllm.distributed.device_communicators.ray_communicator
import
(
RayPPCommunicator
)
register_accelerator_context
(
torch_module_name
=
"cuda"
,
communicator_cls
=
RayPPCommunicator
)
logger
.
info
(
"Using RayPPCommunicator "
"(which wraps vLLM _PP GroupCoordinator) "
"for Ray Compiled Graph communication."
)
else
:
logger
.
info
(
"Using Ray's NCCL communicator for "
"Ray Compiled Graph communication."
)
return
forward_dag
.
experimental_compile
(
enable_asyncio
=
enable_asyncio
,
_overlap_gpu_communication
=
envs
.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment