# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
EPLB communicator implementations and factory.
"""

from abc import ABC, abstractmethod
from collections.abc import Sequence

import torch
from torch.distributed import (
    P2POp,
    ProcessGroup,
    batch_isend_irecv,
)

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import (
    ncclDataTypeEnum,
)
from vllm.distributed.parallel_state import GroupCoordinator, is_local_first_rank
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger

logger = init_logger(__name__)


class EplbCommunicator(ABC):
    """Abstract EPLB communicator for expert weight transfers.

    Callers add point-to-point transfers with :meth:`add_send` and
    :meth:`add_recv`, then call :meth:`execute` to carry them out.
    """

    @abstractmethod
    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Add a send of ``tensor`` to ``dst_rank``."""
        pass

    @abstractmethod
    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Add a receive into ``tensor`` from ``src_rank``."""
        pass

    @abstractmethod
    def execute(self) -> None:
        """Carry out the transfers added since the previous call."""
        pass

    def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
        """Replace the CUDA stream stored on this communicator."""
        self._cuda_stream = cuda_stream

    def _log_initialized(self) -> None:
        """Log the concrete class name when ``is_local_first_rank()`` is true."""
        if is_local_first_rank():
            logger.info("Initialized EPLB communicator: %s.", self.__class__.__name__)


class TorchDistNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by torch.distributed isend/irecv."""

    def __init__(
        self,
        ep_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Store the process group and optional stream; start with no ops."""
        self._ep_group = ep_group
        self._cuda_stream = cuda_stream
        self._p2p_ops: list[P2POp] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Queue an ``isend`` of ``tensor`` to ``dst_rank`` on the EP group."""
        self._p2p_ops.append(
            P2POp(
                torch.distributed.isend,
                tensor,
                dst_rank,
                self._ep_group,
            )
        )

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Queue an ``irecv`` into ``tensor`` from ``src_rank`` on the EP group."""
        self._p2p_ops.append(
            P2POp(
                torch.distributed.irecv,
                tensor,
                src_rank,
                self._ep_group,
            )
        )

    def execute(self) -> None:
        """Issue all queued ops as one batch and wait for every request.

        The queue is emptied afterwards, including when the batch or a wait
        raises. Does nothing when no ops are queued.
        """
        if not self._p2p_ops:
            return
        try:
            with torch.cuda.stream(self._cuda_stream):
                reqs = batch_isend_irecv(self._p2p_ops)
                for req in reqs:
                    req.wait()
        finally:
            self._p2p_ops.clear()


class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
    """EPLB communicator using gloo P2P with CPU staging."""

    def __init__(
        self,
        cpu_group: ProcessGroup,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Store the CPU process group and optional stream; start with no ops."""
        self._cpu_group = cpu_group
        self._cuda_stream = cuda_stream
        # Each entry is ("send" | "recv", tensor, peer_rank).
        self._ops: list[tuple[str, torch.Tensor, int]] = []
        self._log_initialized()

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Record a send of ``tensor`` to ``dst_rank``."""
        self._ops.append(("send", tensor, dst_rank))

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Record a receive into ``tensor`` from ``src_rank``."""
        self._ops.append(("recv", tensor, src_rank))

    def execute(self) -> None:
        """Run the recorded transfers through CPU staging tensors.

        Sends are copied to CPU first; receives land in freshly allocated CPU
        tensors and are copied into their destination tensors once all
        requests have completed. The final copies are enqueued with
        ``non_blocking=True`` and this method does not synchronize the stream
        after them. Does nothing when no ops are recorded.
        """
        if not self._ops:
            return

        p2p_ops: list[P2POp] = []
        # (destination tensor, CPU staging tensor) pairs for received data.
        recv_staging: list[tuple[torch.Tensor, torch.Tensor]] = []

        def build_ops() -> None:
            for op, tensor, peer_rank in self._ops:
                if op == "send":
                    cpu_tensor = tensor.to(device="cpu", non_blocking=True)
                    p2p_ops.append(
                        P2POp(
                            torch.distributed.isend,
                            cpu_tensor,
                            peer_rank,
                            self._cpu_group,
                        )
                    )
                    continue
                cpu_tensor = torch.empty_like(tensor, device="cpu")
                p2p_ops.append(
                    P2POp(
                        torch.distributed.irecv,
                        cpu_tensor,
                        peer_rank,
                        self._cpu_group,
                    )
                )
                recv_staging.append((tensor, cpu_tensor))

        try:
            with torch.cuda.stream(self._cuda_stream):
                build_ops()
        finally:
            self._ops.clear()

        # Wait for all D2H copies to finish
        # before issuing gloo batch_isend_irecv operations.
        if self._cuda_stream is not None:
            self._cuda_stream.synchronize()
        else:
            torch.cuda.current_stream().synchronize()

        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

        if not recv_staging:
            return
        with torch.cuda.stream(self._cuda_stream):
            for dst_tensor, cpu_tensor in recv_staging:
                dst_tensor.copy_(cpu_tensor, non_blocking=True)


class PyNcclEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv."""

    def __init__(
        self,
        pynccl_comm: PyNcclCommunicator,
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Store the PyNccl communicator and optional stream; no group open."""
        self._pynccl_comm = pynccl_comm
        self._cuda_stream = cuda_stream
        # True between the first add_send/add_recv and the next execute().
        self._group_started = False
        self._log_initialized()

    def _ensure_group_started(self) -> None:
        """Call ``group_start()`` once per batch of added operations."""
        if not self._group_started:
            self._pynccl_comm.group_start()
            self._group_started = True

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Open a group if needed, then send ``tensor`` to ``dst_rank``."""
        self._ensure_group_started()
        self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream)

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Open a group if needed, then receive into ``tensor`` from ``src_rank``."""
        self._ensure_group_started()
        self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)

    def execute(self) -> None:
        """Close the open group, if any, via ``group_end()``.

        Does nothing when no operation has been added since the last call.
        """
        if not self._group_started:
            return
        try:
            self._pynccl_comm.group_end()
        finally:
            # Reset even when group_end() raises, mirroring how the other
            # communicators clear their queues in ``finally``. Otherwise the
            # next add_send/add_recv would skip group_start().
            # NOTE(review): assumes a failed group_end() leaves no group
            # open - confirm against PyNcclCommunicator.
            self._group_started = False


def create_eplb_communicator(
    group_coordinator: GroupCoordinator,
    backend: str | None,
    expert_weights: Sequence[torch.Tensor],
) -> EplbCommunicator:
    """Build the EPLB communicator for the requested backend.

    Args:
        group_coordinator: Source of the CPU/device process groups and of the
            device communicator that may carry a ``pynccl_comm``.
        backend: ``"torch_nccl"``, ``"torch_gloo"`` or ``"pynccl"``. ``None``
            is treated as ``"torch_nccl"``.
        expert_weights: Tensors to be transferred. The device type of the
            first one selects the torch process group; all dtypes are
            checked when the pynccl backend is used. An empty sequence is
            treated as CPU.

    Returns:
        The communicator for ``backend``. With a
        ``StatelessGroupCoordinator`` a pynccl communicator is always
        returned; ``"torch_nccl"`` is accepted there with a warning.

    Raises:
        ValueError: If ``backend`` is unknown, or is neither ``"torch_nccl"``
            nor ``"pynccl"`` with a stateless group coordinator.
        RuntimeError: If the pynccl backend is selected but the tensors are
            on CPU, contain unsupported dtypes, no usable ``pynccl_comm`` is
            found, or constructing the communicator fails.
    """
    # Keep a safe default for callers that have not resolved communicator yet.
    if backend is None:
        backend = "torch_nccl"

    tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
    torch_group = (
        group_coordinator.cpu_group
        if tensor_device_type == "cpu"
        else group_coordinator.device_group
    )

    def _create_pynccl() -> EplbCommunicator:
        if tensor_device_type == "cpu":
            raise RuntimeError(
                "EPLB communicator 'pynccl' supports only cuda-like devices "
                f"(got {tensor_device_type})."
            )
        # Deduplicate, then sort by name so the error message is stable.
        unsupported_dtypes = sorted(
            {
                tensor.dtype
                for tensor in expert_weights
                if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
            },
            key=str,
        )
        if unsupported_dtypes:
            raise RuntimeError(
                "EPLB communicator 'pynccl' requested but expert weights contain "
                "unsupported dtypes: "
                f"({', '.join(str(dtype) for dtype in unsupported_dtypes)})."
            )

        device_comm = group_coordinator.device_communicator
        pynccl_comm = (
            getattr(device_comm, "pynccl_comm", None)
            if device_comm is not None
            else None
        )
        if pynccl_comm is None or pynccl_comm.disabled or not pynccl_comm.available:
            raise RuntimeError("EPLB communicator 'pynccl' requested but unavailable.")
        try:
            return PyNcclEplbCommunicator(pynccl_comm=pynccl_comm)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize PyNcclEplbCommunicator ({exc})."
            ) from exc

    is_stateless = isinstance(group_coordinator, StatelessGroupCoordinator)
    if is_stateless:
        if backend not in ("torch_nccl", "pynccl"):
            raise ValueError(
                f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator "
                f"(got '{backend}'). torch_gloo is not supported with stateless groups."
            )
        if backend == "torch_nccl":
            logger.warning(
                "Stateless elastic EP requires PyNCCL backend. "
                "Forcing EPLB communicator to 'pynccl'."
            )
        # Both accepted backends resolve to pynccl for stateless groups.
        return _create_pynccl()

    if backend == "torch_gloo":
        return TorchDistGlooStagedEplbCommunicator(
            cpu_group=group_coordinator.cpu_group,
        )
    if backend == "torch_nccl":
        return TorchDistNcclEplbCommunicator(ep_group=torch_group)
    if backend == "pynccl":
        return _create_pynccl()
    raise ValueError(f"Unknown EPLB communicator backend: {backend}")