Unverified commit 3d2a026f, authored by Wentao Ye and committed by GitHub
Browse files

[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent dddbff46
...@@ -19,6 +19,8 @@ from vllm.distributed import ( ...@@ -19,6 +19,8 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter, tensor_model_parallel_reduce_scatter,
) )
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
from ..utils import ( from ..utils import (
init_test_distributed_environment, init_test_distributed_environment,
...@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker( ...@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
torch.testing.assert_close(recv_dict["f"], test_dict["f"]) torch.testing.assert_close(recv_dict["f"], test_dict["f"])
class _DummyWork:
    """Minimal stand-in for an async communication handle used by the tests.

    Counts how many times `wait()` is invoked so a test can assert whether
    (and how often) a caller synchronized on the handle. It provides both
    methods of the `Handle` protocol (`is_completed` and `wait`) so it can be
    passed wherever a `Handle` is expected.
    """

    def __init__(self) -> None:
        # number of times wait() has been called on this handle
        self.wait_calls = 0

    def is_completed(self) -> bool:
        """Return True once `wait()` has been called at least once."""
        return self.wait_calls > 0

    def wait(self) -> None:
        """Record one wait call; no real communication takes place."""
        self.wait_calls += 1
class _DummyAllGatherGroup:
    """Fake all-gather group that treats every rank as holding the same slice."""

    def __init__(self, world_size: int, rank_in_group: int) -> None:
        self.world_size, self.rank_in_group = world_size, rank_in_group

    def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
        """Return `t` repeated `world_size` times along dimension 0."""
        # only gathering along dim 0 is emulated
        assert dim == 0
        replicas = [t] * self.world_size
        return torch.cat(replicas, dim=0)
def _make_group_for_unit_test(
    rank_in_group: int = 0, world_size: int = 2
) -> GroupCoordinator:
    """Build a bare GroupCoordinator for unit tests.

    The instance is created with `__new__` so that
    `GroupCoordinator.__init__` (which wires up real process groups) never
    runs; only the attributes listed below are populated.
    """
    coordinator = GroupCoordinator.__new__(GroupCoordinator)
    stub_state = {
        "world_size": world_size,
        "rank_in_group": rank_in_group,
        "ranks": list(range(world_size)),
        "use_cpu_custom_send_recv": False,
        "device_group": None,
        "cpu_group": None,
    }
    for attr_name, attr_value in stub_state.items():
        setattr(coordinator, attr_name, attr_value)
    return coordinator
def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that each deferred all-gather postprocess updates its own key.

    `irecv_tensor_dict` is driven with a fake `torch.distributed.irecv` and a
    fake all-gather group. Every received buffer is filled with a distinct
    value so that a closure bound to the wrong key *or* the wrong slice
    tensor produces a detectable mismatch.
    """
    irecv_calls = 0

    def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
        # fill every buffer with a different value (1, 2, ...) so that the
        # received slices are distinguishable from one another.
        nonlocal irecv_calls
        irecv_calls += 1
        t.fill_(irecv_calls)
        return _DummyWork()

    monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
    monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)

    g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
    # 2 tensors so we can catch late-binding bugs in postprocess closures.
    keys = ("a", "b")
    metadata_list = [
        (key, TensorMetadata("cpu", torch.int32, torch.Size([4]))) for key in keys
    ]
    g.recv_object = lambda src=None: metadata_list  # type: ignore[method-assign]

    ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
    td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
    assert td is not None
    assert len(handles) == 2
    assert len(postprocess) == 2

    # before postprocess, dict holds the TP slice (shape 2).
    for key in keys:
        assert td[key].shape == torch.Size([2])

    # snapshot the per-key slices; they must differ for the check to be useful.
    local_slices = {key: td[key].clone() for key in keys}
    assert not torch.equal(local_slices["a"], local_slices["b"])

    # simulate worker-side "defer wait": wait + postprocess later.
    for handle in handles:
        handle.wait()
    for fn in postprocess:
        fn()

    # after postprocess, dict values are reconstructed to full shape (shape 4),
    # and each key must have been rebuilt from its own slice (the fake
    # all-gather group duplicates the local slice across both ranks).
    for key, local in local_slices.items():
        assert td[key].shape == torch.Size([4])
        torch.testing.assert_close(td[key], torch.cat([local, local], dim=0))
def test_async_intermediate_tensors_lazy_wait() -> None:
    """Check that communication is awaited lazily and only once."""
    handle = _DummyWork()
    postprocess_runs: list[int] = []

    def record_postprocess() -> None:
        postprocess_runs.append(1)

    lazy = AsyncIntermediateTensors(
        {"x": torch.tensor([1])},
        comm_handles=[handle],
        comm_postprocess=[record_postprocess],
    )

    # accessing non-tensor attributes should not trigger wait.
    assert lazy.kv_connector_output is None
    assert (handle.wait_calls, len(postprocess_runs)) == (0, 0)

    # first access of `.tensors` triggers wait + postprocess.
    _ = lazy.tensors
    assert (handle.wait_calls, len(postprocess_runs)) == (1, 1)

    # subsequent access should not re-wait.
    _ = lazy.tensors
    assert (handle.wait_calls, len(postprocess_runs)) == (1, 1)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker( def send_recv_test_worker(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
......
...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext ...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any from typing import Any, Protocol
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -64,6 +64,14 @@ class GraphCaptureContext: ...@@ -64,6 +64,14 @@ class GraphCaptureContext:
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
class Handle(Protocol):
    """Structural type for the async work handles used by P2P send/recv.

    Any object exposing these two methods qualifies; no inheritance needed.
    """

    def is_completed(self) -> bool:
        """Completion query; semantics are defined by the concrete handle."""

    def wait(self) -> None:
        """Wait on the operation; semantics are defined by the concrete handle."""
def _split_tensor_dict( def _split_tensor_dict(
tensor_dict: dict[str, torch.Tensor | Any], tensor_dict: dict[str, torch.Tensor | Any],
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
...@@ -780,6 +788,20 @@ class GroupCoordinator: ...@@ -780,6 +788,20 @@ class GroupCoordinator:
async_handle.wait() async_handle.wait()
return tensor_dict return tensor_dict
def _should_use_all_gather(
self,
key: str,
numel: int,
all_gather_group: "GroupCoordinator | None",
all_gather_tensors: dict[str, bool] | None,
) -> bool:
if all_gather_group is None:
return False
use_all_gather = numel % all_gather_group.world_size == 0
if all_gather_tensors is not None:
use_all_gather = all_gather_tensors.get(key, use_all_gather)
return use_all_gather
def send_tensor_dict( def send_tensor_dict(
self, self,
tensor_dict: dict[str, torch.Tensor | Any], tensor_dict: dict[str, torch.Tensor | Any],
...@@ -808,6 +830,35 @@ class GroupCoordinator: ...@@ -808,6 +830,35 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1: if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict return tensor_dict
handles = self.isend_tensor_dict(
tensor_dict,
dst=dst,
all_gather_group=all_gather_group,
all_gather_tensors=all_gather_tensors,
)
for handle in handles:
handle.wait()
return None
def isend_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> list[Handle]:
if self.world_size <= 1:
return []
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
# custom device communicator path is synchronous
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst
)
return []
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = ( all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group 0 if all_gather_group is None else all_gather_group.rank_in_group
...@@ -820,53 +871,31 @@ class GroupCoordinator: ...@@ -820,53 +871,31 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst
)
return None
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(tensor_dict, dict), (
f"Expecting a dictionary, got {type(tensor_dict)}"
)
metadata_list, tensor_list = _split_tensor_dict(tensor_dict) metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst) self.send_object(metadata_list, dst=dst)
tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)] tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
assert len(tensor_keys) == len(tensor_list) assert len(tensor_keys) == len(tensor_list)
handles: list[Handle] = []
for key, tensor in zip(tensor_keys, tensor_list): for key, tensor in zip(tensor_keys, tensor_list):
if tensor.numel() == 0: if tensor.numel() == 0:
# Skip sending empty tensors.
continue continue
# send-allgather: send only a slice, then do allgather. if self._should_use_all_gather(
use_all_gather = ( key, tensor.numel(), all_gather_group, all_gather_tensors
all_gather_group is not None and tensor.numel() % all_gather_size == 0 ):
)
use_all_gather = (
all_gather_tensors.get(key, use_all_gather)
if all_gather_tensors
else use_all_gather
)
if use_all_gather:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu: comm_group = metadata_group if tensor.is_cpu else group
# use metadata_group for CPU tensors handle = torch.distributed.isend(
torch.distributed.send( tensor, dst=self.ranks[dst], group=comm_group
tensor, dst=self.ranks[dst], group=metadata_group )
) if tensor.is_cuda:
else: tensor.record_stream(torch.cuda.current_stream(tensor.device))
# use group for GPU tensors handles.append(handle)
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None return handles
def recv_tensor_dict( def recv_tensor_dict(
self, self,
...@@ -895,6 +924,38 @@ class GroupCoordinator: ...@@ -895,6 +924,38 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1: if not torch.distributed.is_initialized() or self.world_size == 1:
return None return None
tensor_dict, handles, postprocess = self.irecv_tensor_dict(
src=src,
all_gather_group=all_gather_group,
all_gather_tensors=all_gather_tensors,
)
for handle in handles:
handle.wait()
for fn in postprocess:
fn()
return tensor_dict
def irecv_tensor_dict(
self,
src: int | None = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> tuple[
dict[str, torch.Tensor | Any] | None,
list[Handle],
list[Callable[[], None]],
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
# custom device communicator path is synchronous
sync_tensor_dict = self.device_communicator.recv_tensor_dict( # type: ignore
src
)
return sync_tensor_dict, [], []
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = ( all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group 0 if all_gather_group is None else all_gather_group.rank_in_group
...@@ -907,57 +968,57 @@ class GroupCoordinator: ...@@ -907,57 +968,57 @@ class GroupCoordinator:
src = (self.rank_in_group - 1) % self.world_size src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.recv_tensor_dict( # type: ignore
src
)
recv_metadata_list = self.recv_object(src=src) recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {} tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
postprocess: list[Callable[[], None]] = []
for key, value in recv_metadata_list: for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) full_tensor = torch.empty(
if tensor.numel() == 0: value.size, dtype=value.dtype, device=value.device
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (
all_gather_group is not None
and tensor.numel() % all_gather_size == 0
)
use_all_gather = (
all_gather_tensors.get(key, use_all_gather)
if all_gather_tensors
else use_all_gather
) )
if full_tensor.numel() == 0:
tensor_dict[key] = full_tensor
continue
if use_all_gather: if self._should_use_all_gather(
orig_shape = tensor.shape key, full_tensor.numel(), all_gather_group, all_gather_tensors
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] ):
orig_shape = full_tensor.shape
if tensor.is_cpu: slice_tensor = full_tensor.reshape(all_gather_size, -1)[
# use metadata_group for CPU tensors all_gather_rank
torch.distributed.recv( ]
tensor, src=self.ranks[src], group=metadata_group comm_group = metadata_group if slice_tensor.is_cpu else group
handle = torch.distributed.irecv(
slice_tensor, src=self.ranks[src], group=comm_group
) )
handles.append(handle)
def _postprocess(
key: str = key,
slice_tensor: torch.Tensor = slice_tensor,
orig_shape: tuple[int, ...] = tuple(orig_shape),
all_gather_group=all_gather_group,
) -> None:
assert all_gather_group is not None
tensor_dict[key] = all_gather_group.all_gather(
slice_tensor, dim=0
).reshape(orig_shape)
postprocess.append(_postprocess)
tensor_dict[key] = slice_tensor
else: else:
# use group for GPU tensors comm_group = metadata_group if full_tensor.is_cpu else group
torch.distributed.recv(tensor, src=self.ranks[src], group=group) handle = torch.distributed.irecv(
if use_all_gather: full_tensor, src=self.ranks[src], group=comm_group
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0
) )
tensor = tensor.reshape(orig_shape) handles.append(handle)
tensor_dict[key] = full_tensor
tensor_dict[key] = tensor
else: else:
tensor_dict[key] = value tensor_dict[key] = value
return tensor_dict
return tensor_dict, handles, postprocess
def barrier(self): def barrier(self):
"""Barrier synchronization among the group. """Barrier synchronization among the group.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import gc import gc
import os import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from types import NoneType from types import NoneType
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
...@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import ( ...@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
Handle,
get_pcp_group, get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
...@@ -68,6 +70,38 @@ if TYPE_CHECKING: ...@@ -68,6 +70,38 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors with lazy comm synchronization.

    The pending communication handles are waited on, and the postprocess
    callbacks are run, the first time `.tensors` is read (or when
    `wait_for_comm()` is called explicitly). Other attributes can be read
    without triggering synchronization.

    Args:
        tensors: Mapping of tensor name to tensor, passed to the base class.
        comm_handles: Optional async handles to wait on before the tensors
            are used.
        comm_postprocess: Optional callbacks run, in order, after all
            handles have been waited on.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Wait on all handles, then run all postprocess callbacks, once.

        Later calls are no-ops. If a handle or callback raises, the instance
        is not marked as synchronized.
        """
        if self._comm_waited:
            return
        for handle in self._comm_handles or ():
            handle.wait()
        for fn in self._comm_postprocess or ():
            fn()
        self._comm_waited = True
        # Drop the finished handles and callbacks so that whatever they
        # reference (e.g. buffers captured by postprocess closures) is not
        # kept alive for the lifetime of this object.
        self._comm_handles = None
        self._comm_postprocess = None

    def __getattribute__(self, name: str):
        # ensure `.tensors` is ready before use; object.__getattribute__ is
        # used for the internal lookups to avoid recursing into this hook.
        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
            object.__getattribute__(self, "wait_for_comm")()
        return object.__getattribute__(self, name)
class Worker(WorkerBase): class Worker(WorkerBase):
def __init__( def __init__(
self, self,
...@@ -113,6 +147,8 @@ class Worker(WorkerBase): ...@@ -113,6 +147,8 @@ class Worker(WorkerBase):
raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}") raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
# pending non-blocking PP send work from the previous iteration
self._pp_send_work: list[Handle] = []
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
...@@ -600,6 +636,12 @@ class Worker(WorkerBase): ...@@ -600,6 +636,12 @@ class Worker(WorkerBase):
def execute_model( def execute_model(
self, scheduler_output: "SchedulerOutput" self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
# ensure any previous non-blocking PP sends are complete
if self._pp_send_work:
for handle in self._pp_send_work:
handle.wait()
self._pp_send_work = []
intermediate_tensors = None intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0 forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
...@@ -637,12 +679,18 @@ class Worker(WorkerBase): ...@@ -637,12 +679,18 @@ class Worker(WorkerBase):
} }
if forward_pass and not get_pp_group().is_first_rank: if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict( tensor_dict, comm_handles, comm_postprocess = (
all_gather_group=get_tp_group(), get_pp_group().irecv_tensor_dict(
all_gather_tensors=all_gather_tensors, all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
) )
assert tensor_dict is not None assert tensor_dict is not None
intermediate_tensors = IntermediateTensors(tensor_dict) intermediate_tensors = AsyncIntermediateTensors(
tensor_dict,
comm_handles=comm_handles,
comm_postprocess=comm_postprocess,
)
with self.annotate_profile(scheduler_output): with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
...@@ -660,7 +708,8 @@ class Worker(WorkerBase): ...@@ -660,7 +708,8 @@ class Worker(WorkerBase):
and not get_pp_group().is_last_rank and not get_pp_group().is_last_rank
) )
get_pp_group().send_tensor_dict( # launch non-blocking send of intermediate tensors
self._pp_send_work = get_pp_group().isend_tensor_dict(
output.tensors, output.tensors,
all_gather_group=get_tp_group(), all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors, all_gather_tensors=all_gather_tensors,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment