Unverified Commit dea26833 authored by Itay Alroy's avatar Itay Alroy Committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
...@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging. debugging.
""" """
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
def naive_multicast( def naive_multicast(
self, self,
...@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine). all-gather (dispatch) and reduce-scatter (combine).
""" """
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits( def dispatch_router_logits(
self, self,
...@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels. All2All communication based on DeepEP High-Throughput kernels.
""" """
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), ( assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels." " to install DeepEP kernels."
) # noqa ) # noqa
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache() self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish # This is the DeepEP default. Stick to it till we can establish
...@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
pass with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
...@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels. All2All communication based on DeepEP High-Throughput kernels.
""" """
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]: def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests. # Defaults for internode and intranode are taken from DeepEP tests.
...@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
) )
def get_handle(self, kwargs): def get_handle(self, kwargs):
...@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels. All2All communication based on DeepEP Low-Latency kernels.
""" """
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs( def _make_all2all_kwargs(
self, self,
...@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True, allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
) )
def get_handle(self, kwargs): def get_handle(self, kwargs):
...@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase): ...@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int rank: int
world_size: int world_size: int
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), ( assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer" "flashinfer all2all module not found. Please install/check flashinfer"
) # noqa ) # noqa
super().__init__(cpu_group) super().__init__(cpu_group, tcp_store_group)
logger.debug( logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d", "Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank, self.rank,
......
...@@ -29,8 +29,9 @@ class All2AllManagerBase: ...@@ -29,8 +29,9 @@ class All2AllManagerBase:
rank: int rank: int
world_size: int world_size: int
def __init__(self, cpu_group): def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties # compute some common properties
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -47,12 +48,17 @@ class All2AllManagerBase: ...@@ -47,12 +48,17 @@ class All2AllManagerBase:
# when we create this object # when we create this object
self.dp_rank = self.dp_group.rank_in_group self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group) self.rank = cpu_group.rank()
self.world_size = dist.get_world_size(cpu_group) self.world_size = cpu_group.size()
# all2all communication often has separate implementations for # all2all communication often has separate implementations for
# intra-node and inter-node communication # intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs): def get_handle(self, kwargs):
# get a handle for the all2all communication, # get a handle for the all2all communication,
...@@ -121,17 +127,36 @@ class DeviceCommunicatorBase: ...@@ -121,17 +127,36 @@ class DeviceCommunicatorBase:
device: torch.device | None = None, device: torch.device | None = None,
device_group: ProcessGroup | None = None, device_group: ProcessGroup | None = None,
unique_name: str = "", unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
): ):
self.device = device or torch.device("cpu") self.device = device or torch.device("cpu")
self.cpu_group = cpu_group self.cpu_group = cpu_group
self.device_group = device_group self.device_group = device_group
self.unique_name = unique_name self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group) # Check if this is a stateless process group
self.ranks = dist.get_process_group_ranks(cpu_group) from torch.distributed.distributed_c10d import _world
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size() is_stateless = _world.pg_map.get(cpu_group, None) is None
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False use_ep = False
all2all_backend = None all2all_backend = None
...@@ -145,7 +170,7 @@ class DeviceCommunicatorBase: ...@@ -145,7 +170,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None self.all2all_manager: All2AllManagerBase | None = None
...@@ -275,6 +300,13 @@ class DeviceCommunicatorBase: ...@@ -275,6 +300,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group) torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` in place from group rank ``src`` and return it.

    With a single-rank group nothing is communicated and the tensor is
    handed back untouched. Otherwise the group-local ``src`` is translated
    through ``self.ranks`` before calling ``torch.distributed.broadcast``
    on ``self.device_group``.
    """
    if self.world_size != 1:
        source_rank = self.ranks[src]
        torch.distributed.broadcast(tensor, source_rank, self.device_group)
    return tensor
def destroy(self): def destroy(self):
pass pass
...@@ -343,3 +375,6 @@ class DeviceCommunicatorBase: ...@@ -343,3 +375,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class. This is a no-op in the base class.
""" """
return hidden_states return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
# Hook for issuing a batch of point-to-point operations (`p2p_ops`).
# This version provides no implementation and unconditionally raises
# NotImplementedError.
raise NotImplementedError
...@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import ( ...@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None, device: torch.device | None = None,
device_group: ProcessGroup | None = None, device_group: ProcessGroup | None = None,
unique_name: str = "", unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
): ):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name: if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp # custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False use_custom_allreduce = False
...@@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1: if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator( self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group, group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device, device=self.device,
) )
if is_symmetric_memory_enabled(): if is_symmetric_memory_enabled():
...@@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive": if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group) self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter": elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group) self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput": elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency": elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori": elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager from .all2all import MoriAll2AllManager
...@@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv": elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else: else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
...@@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.recv(tensor, self.ranks[src], self.device_group) torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` from rank ``src`` through the PyNCCL communicator.

    Returns the same tensor object. A single-rank group short-circuits
    without touching the communicator.

    Raises:
        ValueError: if there is more than one rank but ``self.pynccl_comm``
            is missing or disabled.
    """
    if self.world_size == 1:
        return tensor

    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.broadcast(tensor, src)
    return tensor
def destroy(self): def destroy(self):
if self.pynccl_comm is not None: if self.pynccl_comm is not None:
self.pynccl_comm = None self.pynccl_comm = None
...@@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states, hidden_states,
is_sequence_parallel, is_sequence_parallel,
) )
def batch_isend_irecv(self, p2p_ops: list):
    """Forward a batch of point-to-point ops to the PyNCCL communicator.

    Raises:
        ValueError: if ``self.pynccl_comm`` is missing or disabled.
    """
    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")
    comm.batch_isend_irecv(p2p_ops)
...@@ -312,10 +312,19 @@ class PyNcclCommunicator: ...@@ -312,10 +312,19 @@ class PyNcclCommunicator:
) )
if stream is None: if stream is None:
stream = current_stream() stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend( self.nccl.ncclSend(
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
tensor.numel(), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), nccl_dtype,
dst, dst,
self.comm, self.comm,
cudaStream_t(stream.cuda_stream), cudaStream_t(stream.cuda_stream),
...@@ -330,10 +339,19 @@ class PyNcclCommunicator: ...@@ -330,10 +339,19 @@ class PyNcclCommunicator:
) )
if stream is None: if stream is None:
stream = current_stream() stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv( self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
tensor.numel(), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), nccl_dtype,
src, src,
self.comm, self.comm,
cudaStream_t(stream.cuda_stream), cudaStream_t(stream.cuda_stream),
...@@ -384,3 +402,17 @@ class PyNcclCommunicator: ...@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window): def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window) return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
    """Issue a batch of sends/receives inside one NCCL group.

    Each element of ``p2p_ops`` must expose ``op`` (either
    ``torch.distributed.isend`` or ``torch.distributed.irecv``), ``tensor``
    and ``group_peer``. Sends and receives are issued in list order between
    ``group_start()`` and ``group_end()``.

    Args:
        p2p_ops: operations to issue.
        stream: stream to issue on; defaults to ``current_stream()``.

    Raises:
        ValueError: if any op is neither ``isend`` nor ``irecv``. The check
            runs before the group is opened, so nothing is issued.

    Does nothing when the communicator is disabled.
    """
    if self.disabled:
        return

    isend = torch.distributed.isend
    irecv = torch.distributed.irecv

    # Reject unsupported ops up front. Silently skipping one would leave the
    # caller believing the transfer was queued, and failing in the middle of
    # an open group would leave group_start() without a matching group_end().
    for op in p2p_ops:
        if op.op is not isend and op.op is not irecv:
            raise ValueError(f"Unsupported P2P op: {op.op!r}")

    if stream is None:
        stream = current_stream()

    self.group_start()
    for op in p2p_ops:
        transfer = self.send if op.op is isend else self.recv
        transfer(op.tensor, op.group_peer, stream)
    self.group_end()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
CompilationMode,
set_current_vllm_config,
)
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
create_standby_groups,
get_standby_dp_group,
get_standby_ep_group,
pop_standby_groups,
)
from vllm.distributed.parallel_state import (
_replace_active_groups,
prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
logger = init_logger(__name__)
def batch_transfer_weights(
    model: nn.Module,
    is_sender: bool,
    peer_rank: int,
    dp_group: StatelessGroupCoordinator,
    expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
    """Send or receive all non-expert tensors of ``model`` to/from one peer.

    Every ``state_dict()`` entry is transferred except those whose name ends
    with ``"expert_map"`` and those whose storage pointer matches a tensor in
    ``expert_weights``. All transfers are handed to the device communicator
    of ``dp_group`` as a single batch.

    Args:
        model: module whose state is exchanged.
        is_sender: ``True`` to send the tensors, ``False`` to receive into
            them.
        peer_rank: rank of the other side, stored as ``group_peer`` on each
            op.
        dp_group: coordinator whose ``device_communicator`` performs the
            batch.
        expert_weights: groups of expert tensors to leave out.

    Raises:
        ValueError: if ``dp_group`` has no device communicator.
    """
    communicator = dp_group.device_communicator
    if communicator is None:
        raise ValueError("No device communicator found")

    # Expert tensors are recognised by their data pointer.
    expert_ptrs = {
        tensor.data_ptr() for group in expert_weights for tensor in group
    }

    tensors_to_move = [
        tensor.data
        for key, tensor in model.state_dict().items()
        if not key.endswith("expert_map")
        and tensor.data_ptr() not in expert_ptrs
    ]
    assert len(tensors_to_move) > 0

    direction = torch.distributed.isend if is_sender else torch.distributed.irecv

    def _make_op(tensor):
        # object.__new__ creates the P2POp without running P2POp's own
        # construction logic; the fields are then filled in by hand.
        op = object.__new__(P2POp)
        op.op = direction
        op.tensor = tensor
        op.group_peer = peer_rank
        return op

    communicator.batch_isend_irecv([_make_op(t) for t in tensors_to_move])
def broadcast_expert_mapping(
    physical_to_logical: torch.Tensor | None,
    num_local_physical_experts: int | None,
    num_logical_experts: int | None,
    dp_group: StatelessGroupCoordinator,
    device: torch.device,
    src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
    """Share the physical-to-logical expert map and two counts from one rank.

    The source rank (``dp_group.rank_in_group == src_rank``) must supply an
    int64 ``physical_to_logical`` and both counts. Other ranks may pass
    ``None`` for them; they receive the shape first, allocate an int64
    buffer on ``device`` and then receive the map itself.

    Communication happens in three steps, in this order: the map's shape
    and the two counts go through ``dp_group.tcp_store_group`` as CPU
    tensors, then the map goes through ``dp_group.broadcast``.

    Returns:
        ``(physical_to_logical, num_local_physical_experts,
        num_logical_experts)`` as seen after the broadcasts.
    """
    is_source = dp_group.rank_in_group == src_rank
    cpu_int64 = {"dtype": torch.int64, "device": "cpu"}

    if is_source:
        assert physical_to_logical is not None
        assert num_local_physical_experts is not None
        assert num_logical_experts is not None
        assert physical_to_logical.dtype == torch.int64
        dims = torch.tensor(list(physical_to_logical.shape), **cpu_int64)
        counts = torch.tensor(
            [num_local_physical_experts, num_logical_experts], **cpu_int64
        )
    else:
        # Receivers allocate room for two shape entries and two counts.
        dims = torch.empty(2, **cpu_int64)
        counts = torch.empty(2, **cpu_int64)

    store_group = dp_group.tcp_store_group
    dims = store_group.broadcast(dims, src_rank)
    counts = store_group.broadcast(counts, src_rank)

    if not is_source:
        assert device is not None
        physical_to_logical = torch.empty(
            tuple(dims.tolist()), dtype=torch.int64, device=device
        )
    assert physical_to_logical is not None
    physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)

    received_local = int(counts[0].item())
    received_logical = int(counts[1].item())
    return physical_to_logical, received_local, received_logical
class ElasticEPScalingExecutor:
def __init__(self, worker):
self.worker_ref = weakref.ref(worker)
self.reconfig_request = None
@property
def worker(self):
worker = self.worker_ref()
if worker is None:
raise RuntimeError("Worker has been garbage collected")
return worker
def execute(self, execute_method: str, *args, **kwargs):
method = getattr(self, execute_method, None)
if method is None:
raise ValueError(f"Unknown execute method: {execute_method}")
return method(*args, **kwargs)
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
# Create the standby process groups described by `reconfig_request` and
# mark EPLB as suppressed on the model runner.
#
# NOTE(review): leading indentation is missing in this view of the file,
# so the exact extent of the `with` block below is not visible here.
#
# Keep the request; switch_and_prepare() reads self.reconfig_request later.
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
world_size = self.worker.vllm_config.parallel_config.world_size
# Total rank count across all DP replicas after the resize.
new_world_size_across_dp = world_size * new_dp_size
# Work on a copy: shallow-copy the config and deep-copy its
# parallel_config, so setting data_parallel_size below does not touch the
# worker's own parallel_config object.
updated_config = copy.copy(self.worker.vllm_config)
updated_config.parallel_config = copy.deepcopy(
self.worker.vllm_config.parallel_config
)
updated_config.parallel_config.data_parallel_size = new_dp_size
with set_current_vllm_config(updated_config):
# Inside a method body this name resolves to the module-level
# create_standby_groups imported from standby_state, not to this method.
create_standby_groups(
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
)
# perform_eplb_reshuffle() sets this flag back to False.
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
# Log once, from rank 0 of the standby EP group only.
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
# Sender side of the scale-up weight copy: work out which new DP ranks
# this worker is responsible for and call batch_transfer_weights() with
# is_sender=True once per such rank, over the standby DP group.
#
# NOTE(review): leading indentation is missing in this view of the file,
# so whether torch.cuda.synchronize() at the end runs once per receiver
# or once after the loop cannot be determined from here.
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
# Broadcast old_dp_size to all workers in standby group
if standby_dp_group.rank_in_group < old_dp_size:
old_dp_size_tensor = torch.tensor(
[old_dp_size], dtype=torch.int64, device="cpu"
)
else:
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
# The broadcast result is not read again in this method; the call is made
# for the benefit of the other ranks in the group.
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
old_dp_size_tensor, 0
)
num_new_workers = new_dp_size - old_dp_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Sender-receiver pairing: the first new_workers % old_dp_size
# senders get (k+1) contiguous receivers, the rest get k
# receivers.
# Example: old_dp_size=2, new_dp_size=5 gives k=1, remainder=1, so
# dp_rank 0 sends to ranks 2 and 3 and dp_rank 1 sends to rank 4.
# NOTE(review): old_dp_size == 0 would raise ZeroDivisionError here;
# presumably callers always pass at least 1 - confirm.
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if dp_rank < remainder:
recv_begin = dp_rank * (num_dst_per_sender + 1)
recv_end = recv_begin + num_dst_per_sender + 1
else:
recv_begin = (
remainder * (num_dst_per_sender + 1)
+ (dp_rank - remainder) * num_dst_per_sender
)
recv_end = recv_begin + num_dst_per_sender
# recv_begin/recv_end index into the new workers; adding old_dp_size turns
# them into DP ranks in the enlarged group.
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
model = self.worker.model_runner.get_model()
# ranks_to_send is built from an ascending range, so sorted() keeps the
# same order.
for new_worker_rank in sorted(ranks_to_send):
batch_transfer_weights(
model=model,
is_sender=True,
peer_rank=new_worker_rank,
dp_group=standby_dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def broadcast_expert_mapping(self) -> None:
    """Broadcast this worker's expert mapping over the standby DP group.

    Reads the physical-to-logical map for the current model from the model
    runner's EPLB state, derives the per-rank and logical expert counts, and
    passes them to the module-level ``broadcast_expert_mapping`` function
    with ``src_rank=0``. The function's return value is discarded.
    """
    dp_group = get_standby_dp_group()
    assert dp_group is not None

    runner = self.worker.model_runner
    model_key = runner.model_config
    eplb = runner.eplb_state
    assert eplb is not None
    model_state = eplb.model_states[model_key.compute_hash()]

    mapping = model_state.physical_to_logical_map
    total_physical = mapping.shape[1]
    # Physical experts are assumed to divide evenly across the active EP
    # group; integer division is used as in the rest of this module.
    per_rank_physical = total_physical // get_ep_group().world_size
    total_logical = model_state.logical_replica_count.shape[1]

    # Inside a method body this name is the module-level function, not this
    # method.
    broadcast_expert_mapping(
        physical_to_logical=mapping,
        num_local_physical_experts=per_rank_physical,
        num_logical_experts=total_logical,
        dp_group=dp_group,
        src_rank=0,
        device=self.worker.device,
    )
def switch_and_remove(self) -> None:
    """Replace every active group slot with ``None``.

    Calls ``_replace_active_groups`` with ``world``, ``dp``, ``ep``,
    ``eplb`` and ``node_count`` all set to ``None``.
    """
    cleared_slots = dict.fromkeys(("world", "dp", "ep", "eplb", "node_count"))
    _replace_active_groups(**cleared_slots)
def switch_and_prepare(self) -> None:
# Promote the standby groups to active and bring this worker's config,
# MoE modules, EPLB bookkeeping, compiled artifacts and block tables in
# line with the new data-parallel / expert-parallel sizes.
#
# Requires self.reconfig_request to have been set (create_standby_groups()
# does that) and the model runner to have a non-None eplb_state.
#
# NOTE(review): leading indentation is missing in this view of the file,
# so the extent of the `with` and `if` bodies below is not visible here;
# the comments describe statements, not nesting.
#
# Sizes of the groups that are active before the swap.
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
# Swap: whatever pop_standby_groups() returns is passed as keyword
# arguments to _replace_active_groups().
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
reconfig_request = self.reconfig_request
assert reconfig_request is not None
new_dp_size = reconfig_request.new_data_parallel_size
# Read after the swap, so this is the world size of the new EP group.
new_ep_size = get_ep_group().world_size
parallel_config.data_parallel_size = new_dp_size
# KEEP_CURRENT_RANK means "leave the existing rank value untouched".
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
# Reconfigure MoE modules with new EP size
# Modules are selected by class name rather than by isinstance().
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
# Assumes the model has at least one such module; an empty list raises
# IndexError here.
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
# Global expert count scales with the new EP size; the per-rank count
# stays as it was.
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
# Update EPLB state
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
num_physical_experts = num_local_experts * new_ep_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
old_physical_to_logical = eplb_model_state.physical_to_logical_map
num_moe_layers = old_physical_to_logical.shape[0]
# num_local_experts is recomputed here from the width of
# expert_load_pass and the old EP size, replacing the value read from
# moe_config above. NOTE(review): presumably both give the same number -
# confirm.
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
if new_dp_size > old_dp_size:
# Scale-up: widen the map to the new physical expert count, fill the
# new columns with -1 and copy the old map into the leading columns.
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_experts * new_ep_size),
-1,
dtype=old_physical_to_logical.dtype,
device=old_physical_to_logical.device,
)
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
old_physical_to_logical
)
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
pad_size = num_physical_experts - old_num_physical_experts
if new_dp_size > old_dp_size:
assert pad_size > 0
# Scale-up: append pad_size zero entries along the last dimension of
# both load tensors.
expanded_expert_load_pass = F.pad(
eplb_model_state.expert_load_pass, (0, pad_size), value=0
)
expanded_expert_load_window = F.pad(
eplb_model_state.expert_load_window, (0, pad_size), value=0
)
eplb_model_state.expert_load_pass = expanded_expert_load_pass
eplb_model_state.expert_load_window = expanded_expert_load_window
eplb_state.num_valid_physical_experts = old_num_physical_experts
else:
# Any case that is not a scale-up lands here; the assert means an
# unchanged physical expert count is rejected.
assert pad_size < 0
# Scale-down: keep only the first num_physical_experts entries along the
# last dimension (2-D for expert_load_pass, 3-D for expert_load_window).
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
:, :num_physical_experts
]
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
:, :, :num_physical_experts
]
eplb_state.num_valid_physical_experts = num_physical_experts
model = self.worker.model_runner.get_model()
# Reset the model's expert weight list before set_eplb_state() is called.
model.expert_weights = []
with set_current_vllm_config(self.worker.vllm_config):
model.set_eplb_state(
eplb_model_state.expert_load_pass,
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
)
# Force re-creation of the modular kernel (and all2all manager)
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model)
if (
self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
# NOTE(yongji): when using stock torch.compile,
# torch.compile is triggered during GPUModelRunner's load_model()
# TODO(yongji):check do we need to re-trigger torch.compile here?
# any changes to the tensor shapes in execution should already
# be handled internally by torch.compile.
backend = self.worker.vllm_config.compilation_config.init_backend(
self.worker.vllm_config
)
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
# Snapshot the GPU and CPU contents of every block table, then clear the
# tables; the snapshots are copied back at the end of this method.
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
saved_block_tables.append(
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
# The workspace is unlocked only for the duration of the warm-up call and
# locked again right after it.
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
# Restore the block table contents saved above, pairing tables with
# snapshots in iteration order.
for bt, (saved_gpu, saved_cpu) in zip(
multi_block_table.block_tables, saved_block_tables
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
    """Run a blocking EPLB expert rearrangement on this worker.

    Async EPLB is switched off for the duration of the call and the
    previous setting is restored afterwards, including when the
    rearrangement raises.

    Args:
        new_dp_size: ``None`` triggers a plain ``eplb_state.rearrange()``.
            Otherwise it is the data-parallel size after a scale-down: old
            EP ranks at or above ``new_dp_size * tensor_parallel_size`` are
            mapped to ``-1`` in the ``rank_mapping`` passed to
            ``rearrange``, and lower ranks map to themselves.
    """
    if get_ep_group().rank == 0:
        logger.info("[Elastic EP] Starting expert resharding...")

    eplb_state = self.worker.model_runner.eplb_state
    assert eplb_state is not None
    model_config = self.worker.model_runner.model_config
    eplb_model_state = eplb_state.model_states[model_config.compute_hash()]

    # Force a synchronous rearrangement. The saved flag is restored in the
    # ``finally`` below so a failure cannot leave async EPLB disabled.
    is_async_enabled = eplb_state.is_async
    eplb_state.is_async = False
    try:
        if new_dp_size is None:
            eplb_state.rearrange()
        else:
            # scale down
            parallel_config = self.worker.vllm_config.parallel_config
            tp_size = parallel_config.tensor_parallel_size
            old_ep_size = parallel_config.data_parallel_size * tp_size
            new_ep_size = new_dp_size * tp_size
            rank_mapping = {
                old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
                for old_ep_rank in range(old_ep_size)
            }
            eplb_state.rearrange(rank_mapping=rank_mapping)
        # NOTE(yongji): check whether we need to synchronize here
        torch.cuda.synchronize()
        # reset expert_rearrangement_step to ensure all ranks are synchronized
        eplb_state.expert_rearrangement_step = 0
        eplb_state.num_valid_physical_experts = (
            eplb_model_state.physical_to_logical_map.shape[1]
        )
    finally:
        eplb_state.is_async = is_async_enabled

    self.worker.model_runner.eep_eplb_suppressed = False
    if get_ep_group().rank == 0:
        logger.info("[Elastic EP] Expert resharding completed")
def receive_weights(self) -> None:
    """Receive model weights on a newly added DP rank from one existing rank.

    The old DP size is obtained through a broadcast (source 0) on the DP
    group's ``tcp_store_group``. New workers are then split into contiguous
    runs over the old ranks: with ``base, extra = divmod(num_new, old)``,
    the first ``extra`` senders each serve ``base + 1`` receivers and the
    rest serve ``base``. This worker's sender is derived from its position
    among the new ranks.
    """
    dp_group = get_dp_group()
    assert isinstance(dp_group, StatelessGroupCoordinator)
    new_dp_size = dp_group.world_size
    dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

    # Receive old_dp_size broadcasted during transfer_weights
    size_buffer = dp_group.tcp_store_group.broadcast(
        torch.empty(1, dtype=torch.int64, device="cpu"), 0
    )
    old_dp_size = int(size_buffer[0].item())

    # Calculate which existing worker will send to this new worker
    base, extra = divmod(new_dp_size - old_dp_size, old_dp_size)
    idx_among_new = dp_rank - old_dp_size
    heavy_span = extra * (base + 1)
    if idx_among_new < heavy_span:
        sender_rank = idx_among_new // (base + 1)
    else:
        sender_rank = extra + (idx_among_new - heavy_span) // base

    model = self.worker.model_runner.get_model()
    batch_transfer_weights(
        model=model,
        is_sender=False,
        peer_rank=sender_rank,
        dp_group=dp_group,
        expert_weights=model.expert_weights,
    )
    torch.cuda.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
    """Receive the expert mapping and pad it out to the new EP size.

    Returns:
        A 3-tuple of:
        - the physical-to-logical map widened to
          ``num_local_physical_experts * new_ep_size`` columns, with the
          received columns copied in front and the rest filled with ``-1``;
        - the number of logical experts reported by the broadcast;
        - the column count of the received (pre-padding) map.
    """
    dp_group = get_dp_group()
    assert isinstance(dp_group, StatelessGroupCoordinator)

    received = broadcast_expert_mapping(
        physical_to_logical=None,
        num_local_physical_experts=None,
        num_logical_experts=None,
        dp_group=dp_group,
        src_rank=0,
        device=self.worker.device,
    )
    physical_to_logical, num_local_physical_experts, num_logical_experts = received

    num_moe_layers = physical_to_logical.shape[0]
    old_num_physical_experts = physical_to_logical.shape[1]

    new_dp_size = get_dp_group().world_size
    tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
    new_ep_size = new_dp_size * tp_size

    # Same dtype/device as the received map; -1 marks slots not yet mapped.
    padded_map = physical_to_logical.new_full(
        (num_moe_layers, num_local_physical_experts * new_ep_size), -1
    )
    padded_map[:, :old_num_physical_experts] = physical_to_logical
    return padded_map, num_logical_experts, old_num_physical_experts
def prepare_new_worker(self) -> None:
    """Prepare communication buffers for this worker's model.

    Runs with the worker's ``vllm_config`` installed as the current config.
    """
    worker = self.worker
    with set_current_vllm_config(worker.vllm_config):
        model = worker.model_runner.get_model()
        prepare_communication_buffer_for_model(model)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
import torch.distributed
from vllm.config import ParallelConfig
from vllm.distributed import (
sched_yield,
stateless_destroy_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.v1.engine import (
EEPNotificationType,
ReconfigureDistributedRequest,
ReconfigureRankType,
)
from vllm.v1.engine.core import DPEngineCoreProc
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
# Module-level logger, named after this module.
logger = init_logger(__name__)
# Role of an engine in a scaling operation, as accepted by
# ElasticEPScalingState(worker_type=...). "new" is checked for scale-up,
# "removing" for scale-down; any other value takes the remaining branch.
WorkerType = Literal["existing", "new", "removing"]
class ScaleUpExistingEngineState(enum.IntEnum):
    """Scale-up steps for an engine that already belongs to the DP group.

    Members are listed in the order ``_progress_existing_engine`` walks them.
    """

    # Idle until handle_notification() sees NEW_CORE_ENGINES_INIT_READY.
    WAIT_NEW_CORE_ENGINES_INIT = 0
    # Barrier on the old group, then _create_standby_groups().
    CREATE_STANDBY_GROUPS = 1
    # _transfer_expert_mapping().
    TRANSFER_EXPERT_MAPPING = 2
    # Idle until NEW_CORE_ENGINES_WEIGHTS_INIT_READY arrives.
    WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
    # Barrier on the old group, then _transfer_weights().
    TRANSFER_WEIGHTS = 4
    # _sync_kv_cache_memory_size().
    SYNC_KV_CACHE_MEMORY_SIZE = 5
    # _switch_and_prepare().
    SWITCH_AND_PREPARE = 6
    # Barrier on the new group, then _eplb_reshuffle().
    EPLB_RESHUFFLE = 7
    # Terminal state.
    COMPLETE = 8
class ScaleUpNewEngineState(enum.IntEnum):
    """Scale-up steps for an engine that is joining the DP group."""

    # All-reduce engine counters over the new DP group.
    PREPARE = 0
    # Barrier on the new group, then _eplb_reshuffle().
    EPLB_RESHUFFLE = 1
    # Terminal state.
    COMPLETE = 2
class ScaleDownRemainingEngineState(enum.IntEnum):
    """Scale-down steps for an engine that stays in the DP group."""

    # Register arrival in the shared engine counter.
    PREPARE = 0
    # Barrier on the old group, then the pre-scale-down EPLB reshuffle.
    EPLB_RESHUFFLE = 1
    # Create standby groups and switch over to them.
    SWITCH_AND_PREPARE = 2
    # Terminal state.
    COMPLETE = 3
class ScaleDownRemovingEngineState(enum.IntEnum):
    """Scale-down steps for an engine that is leaving the DP group."""

    # Register arrival in the shared engine counter.
    PREPARE = 0
    # Barrier on the old group, reshuffle experts, then _switch_and_remove().
    EPLB_RESHUFFLE = 1
    # Terminal state; the engine core is shut down on entering it.
    COMPLETE = 2
class _BarrierTimeoutError(RuntimeError):
    """Timeout in the first stage of the two-staged TCPStore barrier.

    The barrier synchronizes the execution of all engines in the DP group;
    this error signals that not every engine arrived before the deadline.
    """
class ElasticEPScalingState:
def __init__(
    self,
    model_executor: "Executor",
    engine_core: "DPEngineCoreProc",
    vllm_config: "VllmConfig",
    new_parallel_config: ParallelConfig,
    worker_type: WorkerType,
    scale_type: Literal["scale_up", "scale_down"],
    reconfig_request: ReconfigureDistributedRequest | None = None,
):
    """Set up the scaling state machine for one engine.

    The executor and engine core are held through weak references. A
    "new" worker treats the engine core's current DP group/store as the
    *new* group; every other worker type treats them as the *old* group
    and starts with no new group. The initial state depends on
    ``scale_type`` and ``worker_type``.
    """
    self.model_executor_ref = weakref.ref(model_executor)
    self.engine_core_ref = weakref.ref(engine_core)
    self.vllm_config = vllm_config

    joining = worker_type == "new"
    current_core = self.engine_core
    if joining:
        self.old_dp_group = None
        self.old_dp_store = None
    else:
        self.old_dp_group = current_core.dp_group
        self.old_dp_store = current_core.dp_store
    self.new_parallel_config: ParallelConfig = new_parallel_config
    if joining:
        self.new_dp_group: torch.distributed.ProcessGroup | None = (
            current_core.dp_group
        )
        self.new_dp_store = current_core.dp_store
    else:
        self.new_dp_group = None
        self.new_dp_store = None

    self.worker_type = worker_type
    self.scale_type = scale_type
    self.reconfig_request = reconfig_request

    if scale_type == "scale_up":
        if joining:
            self.state = ScaleUpNewEngineState.PREPARE
        else:
            self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
    elif worker_type == "removing":
        self.state = ScaleDownRemovingEngineState.PREPARE
    else:
        self.state = ScaleDownRemainingEngineState.PREPARE
@property
def model_executor(self) -> "Executor":
    """Resolve the weakly referenced model executor.

    Raises:
        RuntimeError: if the executor has already been garbage collected.
    """
    executor = self.model_executor_ref()
    if executor is not None:
        return executor
    raise RuntimeError("Model executor has been garbage collected")
@property
def engine_core(self) -> "DPEngineCoreProc":
    """Resolve the weakly referenced engine core.

    Raises:
        RuntimeError: if the engine core has already been garbage collected.
    """
    core = self.engine_core_ref()
    if core is not None:
        return core
    raise RuntimeError("Engine core has been garbage collected")
def progress(self) -> bool:
    """Advance the state machine matching this engine's role by one step.

    Dispatches on ``scale_type`` and ``worker_type`` to one of the four
    ``_progress_*`` methods and returns its result.
    """
    if self.scale_type == "scale_up":
        if self.worker_type == "new":
            return self._progress_new_engine()
        return self._progress_existing_engine()
    if self.worker_type == "removing":
        return self._progress_removing_engine()
    return self._progress_remaining_engine()
def _execute_tcp_store_barrier(
    self, dp_store, group_rank, group_size, barrier_id, timeout=None
):
    """Spin until every rank has published its arrival key in ``dp_store``.

    This rank first sets ``arrival_{barrier_id}_{group_rank}``, then polls
    the keys of all ``group_size`` ranks, yielding the CPU between polling
    rounds while some are still missing.

    Args:
        dp_store: Store object providing ``set`` and ``check``.
        group_rank: This participant's rank in the group.
        group_size: Number of participants expected to arrive.
        barrier_id: Identifier embedded in the arrival key names.
        timeout: Optional ``timedelta``; ``None`` waits indefinitely.

    Raises:
        _BarrierTimeoutError: if ``timeout`` elapses before all ranks arrive.
    """
    dp_store.set(f"arrival_{barrier_id}_{group_rank}", b"1")

    # Use the monotonic clock: wall-clock time can jump (NTP, manual
    # adjustment), which would fire the timeout early or never.
    deadline = (
        None if timeout is None else time.monotonic() + timeout.total_seconds()
    )
    pending = list(range(group_size))
    while pending:
        if deadline is not None and time.monotonic() > deadline:
            raise _BarrierTimeoutError(
                f"Barrier timed out after {timeout.total_seconds()} seconds"
            )
        # Only re-check ranks that have not been seen yet, in rank order.
        pending = [
            rank
            for rank in pending
            if not dp_store.check([f"arrival_{barrier_id}_{rank}"])
        ]
        if pending:
            sched_yield()
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
    """
    Run the two-staged barrier that lines up all engines in the DP group.

    Reconfiguration notifications may reach some DP EngineCores later than
    others, so a late EngineCore can already be inside an engine step
    (model forward) of the busy loop while the others try to reconfigure.
    The early ones must then skip reconfiguration, run one more forward
    step, and meet everybody on the following step.

    Stage one: the first attempt uses a timeout. If it expires, some
    EngineCores have already entered an engine step; the caller gets
    ``False``, proceeds to its own engine step, and a sync marker is left
    in the store. Stage two: once the marker exists, the barrier is run
    without a timeout.

    Returns:
        ``True`` when the barrier completed, ``False`` when the first-stage
        timeout expired.
    """
    if use_new_group:
        dp_store, dp_group = self.new_dp_store, self.new_dp_group
    else:
        dp_store, dp_group = self.old_dp_store, self.old_dp_group
    assert dp_group is not None
    group_rank = dp_group.rank()
    group_size = dp_group.size()

    barrier_id = f"eep_barrier_{barrier_name}"
    sync_key = f"{barrier_id}_sync"
    # TODO(yongji): figure out appropriate timeout for the barrier
    if dp_store.check([sync_key]):
        timeout = None
    else:
        timeout = timedelta(seconds=5)

    try:
        self._execute_tcp_store_barrier(
            dp_store, group_rank, group_size, barrier_id, timeout=timeout
        )
    except _BarrierTimeoutError as e:
        if timeout is None:
            raise RuntimeError("Unexpected timeout encountered") from e
        # Leave the marker so the next attempt waits without a timeout.
        dp_store.compare_set(sync_key, "", b"1")
        return False

    torch.distributed.barrier(dp_group)
    if group_rank == 0:
        # Rank 0 cleans up the marker and every arrival key.
        dp_store.delete_key(sync_key)
        for peer in range(group_size):
            dp_store.delete_key(f"arrival_{barrier_id}_{peer}")
    return True
def _progress_existing_engine(self) -> bool:
    """Advance the scale-up state machine of an already-running engine.

    Returns:
        ``False`` when the engine has to wait (for a notification, for the
        shared engine counter to reach the group size, or for a first-stage
        barrier timeout); ``True`` when a step ran or the machine is
        already complete.
    """
    States = ScaleUpExistingEngineState
    current = self.state
    counter_key = "eep_barrier_engine_count"

    if current == States.WAIT_NEW_CORE_ENGINES_INIT:
        return False

    if current == States.CREATE_STANDBY_GROUPS:
        # NOTE(yongji): wait for all existing workers to receive the request
        arrived = int(self.old_dp_store.get(counter_key))
        if arrived < self.old_dp_group.size():
            return False
        if not self._staged_barrier(
            use_new_group=False, barrier_name="create_standby_groups"
        ):
            return False
        if self.old_dp_group.rank() == 0:
            self.old_dp_store.delete_key(counter_key)
        self._create_standby_groups()
        self.state = States.TRANSFER_EXPERT_MAPPING
        return True

    if current == States.TRANSFER_EXPERT_MAPPING:
        self._transfer_expert_mapping()
        self.state = States.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
        return True

    if current == States.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
        return False

    if current == States.TRANSFER_WEIGHTS:
        arrived = int(self.old_dp_store.get(counter_key))
        if arrived < self.old_dp_group.size():
            return False
        if not self._staged_barrier(
            use_new_group=False, barrier_name="transfer_weights"
        ):
            return False
        if self.old_dp_group.rank() == 0:
            self.old_dp_store.delete_key(counter_key)
        self._transfer_weights()
        self.state = States.SYNC_KV_CACHE_MEMORY_SIZE
        return True

    if current == States.SYNC_KV_CACHE_MEMORY_SIZE:
        self._sync_kv_cache_memory_size()
        self.state = States.SWITCH_AND_PREPARE
        return True

    if current == States.SWITCH_AND_PREPARE:
        self._switch_and_prepare()
        self.state = States.EPLB_RESHUFFLE
        # Register on the new store for the upcoming reshuffle barrier.
        self.new_dp_store.add(counter_key, 1)
        return True

    if current == States.EPLB_RESHUFFLE:
        assert self.new_dp_group is not None
        arrived = int(self.new_dp_store.get(counter_key))
        if arrived < self.new_dp_group.size():
            return False
        if not self._staged_barrier(
            use_new_group=True, barrier_name="eplb_reshuffle"
        ):
            return False
        if self.new_dp_group.rank() == 0:
            self.new_dp_store.delete_key(counter_key)
        self._eplb_reshuffle()
        self.state = States.COMPLETE
        self._update_parallel_config()
        return True

    assert self.state == States.COMPLETE
    return True
def _progress_new_engine(self) -> bool:
    """Advance the scale-up state machine of a newly started engine.

    In ``PREPARE`` the engine contributes zeros to a MAX all-reduce over
    the new DP group and stores the reduced values as its
    ``engines_running``, ``current_wave`` and ``step_counter``.

    Returns:
        ``False`` while waiting on the engine counter or a first-stage
        barrier timeout; ``True`` otherwise.
    """
    States = ScaleUpNewEngineState
    current = self.state
    assert self.new_dp_group is not None
    counter_key = "eep_barrier_engine_count"

    if current == States.PREPARE:
        counters = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
        torch.distributed.all_reduce(
            counters,
            op=torch.distributed.ReduceOp.MAX,
            group=self.new_dp_group,
        )
        running, wave, step = counters.tolist()
        core = self.engine_core
        core.engines_running = bool(running)
        core.current_wave = int(wave)
        core.step_counter = int(step)
        self.state = States.EPLB_RESHUFFLE
        self.new_dp_store.add(counter_key, 1)
        return True

    if current == States.EPLB_RESHUFFLE:
        arrived = int(self.new_dp_store.get(counter_key))
        if arrived < self.new_dp_group.size():
            return False
        if not self._staged_barrier(
            use_new_group=True, barrier_name="eplb_reshuffle"
        ):
            return False
        # A new engine is never expected to hold rank 0 of the new group.
        assert self.new_dp_group.rank() > 0
        self._eplb_reshuffle()
        self.state = States.COMPLETE
        return True

    assert self.state == States.COMPLETE
    return True
def _progress_remaining_engine(self) -> bool:
    """Advance the scale-down state machine of an engine that stays.

    Returns:
        ``False`` while waiting on the engine counter or a first-stage
        barrier timeout; ``True`` otherwise.
    """
    States = ScaleDownRemainingEngineState
    current = self.state
    counter_key = "eep_barrier_engine_count"

    if current == States.PREPARE:
        self.state = States.EPLB_RESHUFFLE
        self.old_dp_store.add(counter_key, 1)
        return True

    if current == States.EPLB_RESHUFFLE:
        arrived = int(self.old_dp_store.get(counter_key))
        if arrived < self.old_dp_group.size():
            return False
        if not self._staged_barrier(
            use_new_group=False, barrier_name="eplb_reshuffle"
        ):
            return False
        if self.old_dp_group.rank() == 0:
            self.old_dp_store.delete_key(counter_key)
        self._eplb_reshuffle_before_scale_down()
        self.state = States.SWITCH_AND_PREPARE
        # NOTE(yongji): currently, after EPLB reshuffle
        # that redistributes experts to remaining workers, workers
        # to be removed will immediately initiate shutdown;
        # existing workers can no longer execute forward steps using
        # the old setup. In the future, we may keep
        # the removing workers alive a bit longer,
        # e.g., to drain in-batch requests.
        self._create_standby_groups()
        self._switch_and_prepare()
        self._update_parallel_config()
        self.state = States.COMPLETE
        return True

    assert self.state == States.COMPLETE
    return True
def _progress_removing_engine(self) -> bool:
    """Advance the scale-down state machine of an engine being removed.

    After the reshuffle and ``_switch_and_remove`` the engine sends a
    ``SHUTDOWN_COMPLETE`` notification and shuts its engine core down.

    Returns:
        ``False`` while waiting on the engine counter or a first-stage
        barrier timeout; ``True`` otherwise.
    """
    States = ScaleDownRemovingEngineState
    current = self.state
    counter_key = "eep_barrier_engine_count"

    if current == States.PREPARE:
        self.state = States.EPLB_RESHUFFLE
        self.old_dp_store.add(counter_key, 1)
        return True

    if current != States.EPLB_RESHUFFLE:
        assert self.state == States.COMPLETE
        return True

    arrived = int(self.old_dp_store.get(counter_key))
    if arrived < self.old_dp_group.size():
        return False
    if not self._staged_barrier(
        use_new_group=False, barrier_name="eplb_reshuffle"
    ):
        return False
    # A removing engine is never expected to hold rank 0 of the old group.
    assert self.old_dp_group.rank() > 0
    self._eplb_reshuffle_before_scale_down()
    self._switch_and_remove()
    self.state = States.COMPLETE
    self.engine_core._eep_send_engine_core_notification(
        EEPNotificationType.SHUTDOWN_COMPLETE
    )
    self.engine_core.shutdown()
    return True
def handle_notification(self, notification_type: EEPNotificationType):
    """React to a readiness notification from the new core engines.

    Only two (notification, current state) pairs have an effect; each one
    increments the shared engine counter on the old store and moves the
    existing-engine state machine out of its waiting state. Anything else
    is ignored. Must not be called on a "new" worker.
    """
    assert self.worker_type != "new"
    States = ScaleUpExistingEngineState

    if (
        notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
        and self.state == States.WAIT_NEW_CORE_ENGINES_INIT
    ):
        next_state = States.CREATE_STANDBY_GROUPS
    elif (
        notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
        and self.state == States.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
    ):
        next_state = States.TRANSFER_WEIGHTS
    else:
        return

    self.old_dp_store.add("eep_barrier_engine_count", 1)
    self.state = next_state
def is_complete(self) -> bool:
    """Tell whether this engine's state machine reached its COMPLETE state.

    The COMPLETE member compared against is the one of the enum matching
    ``scale_type`` and ``worker_type``.
    """
    if self.scale_type == "scale_up":
        if self.worker_type == "new":
            terminal = ScaleUpNewEngineState.COMPLETE
        else:
            terminal = ScaleUpExistingEngineState.COMPLETE
    elif self.worker_type == "removing":
        terminal = ScaleDownRemovingEngineState.COMPLETE
    else:
        terminal = ScaleDownRemainingEngineState.COMPLETE
    return self.state == terminal
def _create_standby_groups(self):
    """Create the new DP group/store and ask workers to build standby groups.

    The engine-side group and store come from the new parallel config;
    workers receive the ``create_standby_groups`` command together with
    the reconfiguration request.
    """
    group_and_store = self.new_parallel_config.stateless_init_dp_group(
        return_store=True
    )
    self.new_dp_group, self.new_dp_store = group_and_store

    rpc_args = ("create_standby_groups", self.reconfig_request)
    self.model_executor.collective_rpc("elastic_ep_execute", args=rpc_args)

    if self.old_dp_group.rank() == 0:
        logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self):
    """Ask workers to send weights to the new workers.

    Passes the old DP size (size of the old DP group) and the requested
    new DP size along with the ``transfer_weights`` command.
    """
    assert self.reconfig_request is not None
    dp_sizes = (
        self.old_dp_group.size(),
        self.reconfig_request.new_data_parallel_size,
    )
    self.model_executor.collective_rpc(
        "elastic_ep_execute", args=("transfer_weights", *dp_sizes)
    )
    if self.old_dp_group.rank() == 0:
        logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self):
    """Ask workers to run the ``broadcast_expert_mapping`` command."""
    rpc_args = ("broadcast_expert_mapping",)
    self.model_executor.collective_rpc("elastic_ep_execute", args=rpc_args)
    if self.old_dp_group.rank() == 0:
        logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
def _sync_kv_cache_memory_size(self):
    """Share this engine's available KV-cache memory over the new DP group.

    Requires a positive ``available_gpu_memory_for_kv_cache`` on the
    engine core and an already created new DP group.
    """
    available = self.engine_core.available_gpu_memory_for_kv_cache
    assert available > 0
    assert self.new_dp_group is not None
    ParallelConfig.sync_kv_cache_memory_size(self.new_dp_group, available)
    if self.old_dp_group.rank() == 0:
        logger.info("[Elastic EP] Synced KV cache memory size to new workers")
def _switch_and_prepare(self):
    """Switch workers and the engine core over to the new DP group.

    Steps, in order: run ``switch_and_prepare`` on the workers, destroy
    the old DP process group, point the engine core at the new group,
    rank and store, then MAX all-reduce ``engines_running``,
    ``current_wave`` and ``step_counter`` over the new group and store the
    reduced values. Rank 0 of the new group finally sends
    ``RECONFIGURE_FINISHED`` and logs the switch.
    """
    self.model_executor.collective_rpc(
        "elastic_ep_execute", args=("switch_and_prepare",)
    )
    stateless_destroy_torch_distributed_process_group(self.old_dp_group)

    assert self.new_dp_group is not None
    new_dp_group = self.new_dp_group
    core = self.engine_core
    core.dp_group = new_dp_group
    core.dp_rank = new_dp_group.rank()
    core.dp_store = self.new_dp_store

    counters = torch.tensor(
        [int(core.engines_running), core.current_wave, core.step_counter],
        dtype=torch.int32,
        device="cpu",
    )
    torch.distributed.all_reduce(
        counters, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
    )
    running, wave, step = counters.tolist()
    core.engines_running = bool(running)
    core.current_wave = int(wave)
    core.step_counter = int(step)

    if new_dp_group.rank() == 0:
        core._eep_send_engine_core_notification(
            EEPNotificationType.RECONFIGURE_FINISHED
        )
        logger.info("[Elastic EP] Switched to new setup")
def _eplb_reshuffle(self):
    """Ask workers to run ``perform_eplb_reshuffle`` without arguments."""
    rpc_args = ("perform_eplb_reshuffle",)
    self.model_executor.collective_rpc("elastic_ep_execute", args=rpc_args)
    assert self.new_dp_group is not None
    if self.new_dp_group.rank() == 0:
        logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self):
    """Ask workers to reshuffle experts ahead of a scale-down.

    The requested new data-parallel size is forwarded as the argument of
    ``perform_eplb_reshuffle``.
    """
    assert self.reconfig_request is not None
    target_dp_size = self.reconfig_request.new_data_parallel_size
    self.model_executor.collective_rpc(
        "elastic_ep_execute",
        args=("perform_eplb_reshuffle", target_dp_size),
    )
    if self.old_dp_group.rank() == 0:
        logger.info("[Elastic EP] EPLB reshuffle completed")
def _switch_and_remove(self):
    """Ask workers to run the ``switch_and_remove`` command."""
    rpc_args = ("switch_and_remove",)
    self.model_executor.collective_rpc("elastic_ep_execute", args=rpc_args)
def _update_parallel_config(self):
    """Copy the reconfiguration request into ``vllm_config.parallel_config``.

    The DP size, master IP/port and all port lists are always overwritten.
    The DP rank and local DP rank are only overwritten when the request
    does not carry ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
    """
    assert self.reconfig_request is not None
    request = self.reconfig_request
    config = self.vllm_config.parallel_config
    keep_rank = ReconfigureRankType.KEEP_CURRENT_RANK

    config.data_parallel_size = request.new_data_parallel_size
    if request.new_data_parallel_rank != keep_rank:
        config.data_parallel_rank = request.new_data_parallel_rank
    if request.new_data_parallel_rank_local != keep_rank:
        config.data_parallel_rank_local = request.new_data_parallel_rank_local

    config.data_parallel_master_ip = request.new_data_parallel_master_ip
    config.data_parallel_master_port = request.new_data_parallel_master_port

    # Port lists used when creating stateless groups.
    config._data_parallel_master_port_list = (
        request.new_data_parallel_master_port_list
    )
    config._stateless_world_group_port_list = (
        request.new_stateless_world_group_port_list
    )
    config._stateless_dp_group_port_list = request.new_stateless_dp_group_port_list
    config._stateless_ep_group_port_list = request.new_stateless_ep_group_port_list
    config._stateless_eplb_group_port_list = (
        request.new_stateless_eplb_group_port_list
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
# Standby communication groups. They are populated by create_standby_groups()
# and handed out (then reset to None) by pop_standby_groups().
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
# Result of _node_count() on the standby world group's tcp_store_group.
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
# Only set when create_standby_groups() is given eplb_group_ports.
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
    """Return the standby DP group, or ``None`` when none is staged."""
    return _STANDBY_DP
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
    """Return the standby EP group, or ``None`` when none is staged."""
    return _STANDBY_EP
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
    """Return the standby EPLB group, or ``None`` when none is staged."""
    return _STANDBY_EPLB
def get_standby_world_group() -> StatelessGroupCoordinator | None:
    """Return the standby world group, or ``None`` when none is staged."""
    return _STANDBY_WORLD
def create_standby_groups(
    new_dp_size: int,
    new_world_size_across_dp: int,
    master_ip: str,
    world_group_ports: list[list[int]],
    dp_group_ports: list[list[int]],
    ep_group_ports: list[list[int]],
    eplb_group_ports: list[list[int]] | None = None,
    backend: str | None = None,
) -> None:
    """Build the standby world/DP/EP (and optionally EPLB) groups.

    Ranks ``0 .. new_world_size_across_dp - 1`` are arranged as a grid of
    shape ``(-1, new_dp_size, pp_size, tp_size)``. DP groups take one rank
    per DP index; EP groups span ``new_dp_size * tp_size`` ranks. The
    results are stored in the module-level ``_STANDBY_*`` variables.

    Args:
        new_dp_size: Data-parallel size of the standby setup.
        new_world_size_across_dp: Total rank count; must equal the current
            ``torch.distributed`` world size times ``new_dp_size``.
        master_ip: Address passed to every group initialization.
        world_group_ports: Ports for the world group.
        dp_group_ports: Ports for the DP groups.
        ep_group_ports: Ports for the EP groups.
        eplb_group_ports: Ports for the EPLB groups; when ``None`` no EPLB
            group is created.
        backend: Backend name; falls back to the current world group's.
    """
    global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT
    global _STANDBY_DP, _STANDBY_EP, _STANDBY_EPLB

    assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
    current_world = get_world_group()
    assert isinstance(current_world, StatelessGroupCoordinator)
    if not backend:
        backend = current_world.backend

    # World group: a single group holding every rank.
    _STANDBY_WORLD = _init_stateless_group(
        [list(range(new_world_size_across_dp))],
        "world",
        world_group_ports,
        master_ip,
        backend,
        use_device_communicator=False,
    )
    _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size
    rank_grid = torch.arange(new_world_size_across_dp).reshape(
        -1, new_dp_size, pp_size, tp_size
    )

    # DP groups: move the DP axis last, one row per group.
    dp_rank_lists = rank_grid.transpose(1, 3).reshape(-1, new_dp_size).tolist()
    _STANDBY_DP = _init_stateless_group(
        dp_rank_lists, "dp", dp_group_ports, master_ip, backend
    )

    # EP groups: DP and TP axes are flattened together.
    ep_rank_lists = (
        rank_grid.transpose(1, 2).reshape(-1, new_dp_size * tp_size).tolist()
    )
    _STANDBY_EP = _init_stateless_group(
        ep_rank_lists, "ep", ep_group_ports, master_ip, backend
    )

    # EPLB groups reuse the EP rank layout on their own ports.
    if eplb_group_ports is not None:
        _STANDBY_EPLB = _init_stateless_group(
            ep_rank_lists, "eplb", eplb_group_ports, master_ip, backend
        )
def pop_standby_groups() -> dict:
    """Hand out every standby group and reset the standby state.

    Returns:
        A dict with keys ``world``, ``dp``, ``ep``, ``eplb`` and
        ``node_count`` holding the values staged at call time. Afterwards
        all module-level standby variables are ``None``.
    """
    global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT
    global _STANDBY_DP, _STANDBY_EP, _STANDBY_EPLB

    staged = {
        "world": _STANDBY_WORLD,
        "dp": _STANDBY_DP,
        "ep": _STANDBY_EP,
        "eplb": _STANDBY_EPLB,
        "node_count": _STANDBY_WORLD_NODE_COUNT,
    }
    _STANDBY_WORLD = _STANDBY_WORLD_NODE_COUNT = None
    _STANDBY_DP = _STANDBY_EP = _STANDBY_EPLB = None
    return staged
...@@ -24,7 +24,6 @@ logger = init_logger(__name__) ...@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker( def start_async_worker(
state: "EplbState", state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False, is_profile: bool = False,
) -> threading.Thread: ) -> threading.Thread:
eplb_group = get_eplb_group().device_group eplb_group = get_eplb_group().device_group
...@@ -45,7 +44,6 @@ def start_async_worker( ...@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group, eplb_group=eplb_group,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
is_profile=is_profile, is_profile=is_profile,
rank_mapping=rank_mapping,
) )
) )
except Exception as exc: # pragma: no cover - diagnostic path except Exception as exc: # pragma: no cover - diagnostic path
...@@ -107,7 +105,6 @@ async def transfer_run_periodically( ...@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup, eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream, cuda_stream: torch.cuda.Stream,
is_profile: bool = False, is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None: ) -> None:
while True: while True:
await asyncio.to_thread(state.rearrange_event.wait) await asyncio.to_thread(state.rearrange_event.wait)
...@@ -176,7 +173,6 @@ async def transfer_run_periodically( ...@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group, ep_group=eplb_group,
is_profile=is_profile, is_profile=is_profile,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
) )
event = torch.cuda.Event(blocking=False) event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event) cuda_stream.record_event(event)
......
...@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import ( ...@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count, get_node_count,
in_the_same_node_as, in_the_same_node_as,
) )
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.interfaces import MixtureOfExperts
...@@ -302,6 +303,14 @@ class EplbState: ...@@ -302,6 +303,14 @@ class EplbState:
""" """
CUDA device index for the async EPLB worker thread. CUDA device index for the async EPLB worker thread.
""" """
self.num_valid_physical_experts: int = 0
"""
Number of valid physical experts.
This is the number of physical experts that are
actually mapped to logical experts. In elastic EP,
newly started EP ranks may not have physical experts
mapped yet.
"""
if self.device.type == "cuda": if self.device.type == "cuda":
self.cuda_device_index = self.device.index self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available(): if self.cuda_device_index is None and torch.cuda.is_available():
...@@ -367,9 +376,6 @@ class EplbState: ...@@ -367,9 +376,6 @@ class EplbState:
self, self,
model: MixtureOfExperts, model: MixtureOfExperts,
model_config: ModelConfig, model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
): ):
""" """
Build the initial EPLB state. Build the initial EPLB state.
...@@ -462,75 +468,15 @@ class EplbState: ...@@ -462,75 +468,15 @@ class EplbState:
) )
self.expert_rearrangement_step_interval = eplb_step_interval self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type] self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type) logger.debug("Selected EPLB policy: %s", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
else:
new_physical_to_logical_map = None
new_logical_to_physical_map = None
new_logical_replica_count = None
model.set_eplb_state( model.set_eplb_state(
expert_load_pass, expert_load_pass,
logical_to_physical_map, logical_to_physical_map,
logical_replica_count, logical_replica_count,
) )
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
...@@ -561,11 +507,12 @@ class EplbState: ...@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]), recv_dst_rows=np.array([]),
), ),
cuda_device_index=self.cuda_device_index, cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map, new_physical_to_logical_map=None,
new_logical_to_physical_map=new_logical_to_physical_map, new_logical_to_physical_map=None,
new_logical_replica_count=new_logical_replica_count, new_logical_replica_count=None,
) )
self.model_states[model_config.compute_hash()] = model_state self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts
def step( def step(
self, self,
...@@ -696,8 +643,6 @@ class EplbState: ...@@ -696,8 +643,6 @@ class EplbState:
def rearrange( def rearrange(
self, self,
is_profile: bool = False, is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
""" """
...@@ -707,12 +652,6 @@ class EplbState: ...@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement. is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory, This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False. no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP. when scaling is done in EEP.
""" """
...@@ -734,67 +673,34 @@ class EplbState: ...@@ -734,67 +673,34 @@ class EplbState:
"(profile)" if is_profile else "", "(profile)" if is_profile else "",
) )
if global_expert_loads is None: # Map the physical expert load to global logical experts
# Map the physical expert load to global logical experts global_expert_load_windows = []
global_expert_load_windows = [] for eplb_model_state in self.model_states.values():
if not execute_shuffle: expert_load_window = eplb_model_state.expert_load_window[
num_models = torch.tensor( :, :, : self.num_valid_physical_experts
[len(self.model_states)], dtype=torch.int32, device="cpu" ]
) logical_expert_load_window = torch.zeros(
torch.distributed.broadcast( self.expert_load_window_size,
num_models, group=get_ep_group().cpu_group, group_src=0 eplb_model_state.model.num_moe_layers,
) eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
for eplb_model_state in self.model_states.values(): device=eplb_model_state.expert_load_window.device,
logical_expert_load_window = torch.zeros( )
self.expert_load_window_size, logical_expert_load_window.scatter_add_(
eplb_model_state.model.num_moe_layers, dim=-1,
eplb_model_state.model.num_logical_experts, index=eplb_model_state.physical_to_logical_map[
dtype=eplb_model_state.expert_load_window.dtype, :, : self.num_valid_physical_experts
device=eplb_model_state.expert_load_window.device, ]
) .unsqueeze(0)
logical_expert_load_window.scatter_add_( .expand_as(expert_load_window)
dim=-1, .long(),
index=eplb_model_state.physical_to_logical_map.unsqueeze(0) src=expert_load_window,
.expand_as(eplb_model_state.expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
) )
if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip( global_expert_load_window = logical_expert_load_window.sum(dim=0)
self.model_states.values(), global_expert_load_windows global_expert_load_windows.append(global_expert_load_window)
): # Perform all-reduce to get the expert load across all ranks for each model
# (num_moe_layers, old_num_physical_experts) global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_windows = global_expert_loads
# TODO(bowen): Treat differently for prefill and decode nodes # TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values())) eplb_model_state = next(iter(self.model_states.values()))
...@@ -806,8 +712,10 @@ class EplbState: ...@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on # NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown # remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released. # the GPUs to be released.
cpu_group = get_ep_group().cpu_group coordinator = get_ep_group()
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) assert isinstance(coordinator, StatelessGroupCoordinator)
tcp_store_group = coordinator.tcp_store_group
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = ( num_replicas = (
num_replicas // ep_group.size() * num_gpus num_replicas // ep_group.size() * num_gpus
...@@ -933,7 +841,6 @@ class EplbState: ...@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None: if self.async_worker is None:
self.async_worker = start_async_worker( self.async_worker = start_async_worker(
self, self,
rank_mapping=rank_mapping,
is_profile=is_profile, is_profile=is_profile,
) )
...@@ -1089,83 +996,6 @@ class EplbState: ...@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None model_state.new_logical_replica_count = None
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Collect, for every model, the global expert load and the previous
    physical expert placement sent from rank 0 of the EP group.

    Protocol, in the order the collectives are issued:
      1. The number of models is broadcast on the EP CPU group.
      2. For each model, a three-element int32 record is broadcast on the
         CPU group: MoE layer count, logical expert count, and old
         physical expert count.
      3. The expert load is obtained with an all-reduce on the device
         group; this rank contributes an all-zero tensor.
      4. The old placement is broadcast on the device group.

    Returns:
        Two lists holding one tensor per model: the expert loads, shaped
        `(num_moe_layers, num_logical_experts)`, and the old expert
        indices, shaped `(num_moe_layers, num_old_physical_experts)`.
    """
    group = get_ep_group()

    model_count_buf = torch.empty(1, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(model_count_buf, group=group.cpu_group, group_src=0)
    model_count = model_count_buf.item()

    loads: list[torch.Tensor] = []
    placements: list[torch.Tensor] = []
    for _ in range(model_count):
        # Shapes of the tensors that follow are announced on the CPU group.
        shape_info = torch.empty(3, dtype=torch.int32, device="cpu")
        torch.distributed.broadcast(shape_info, group=group.cpu_group, group_src=0)
        layer_count, logical_count, old_physical_count = shape_info.tolist()

        # Zero-initialised so the all-reduce result is unaffected by this rank.
        load = torch.zeros(
            (layer_count, logical_count),
            dtype=torch.int64,
            device=group.device,
        )
        all_reduce(load, group=group.device_group)

        placement = torch.empty(
            (layer_count, old_physical_count),
            dtype=torch.int64,
            device=group.device,
        )
        torch.distributed.broadcast(
            placement,
            group=group.device_group,
            group_src=0,
        )

        loads.append(load)
        placements.append(placement)

    return loads, placements
@classmethod
def get_eep_state(
    cls, parallel_config: ParallelConfig
) -> tuple[
    list[torch.Tensor] | None,
    list[torch.Tensor] | None,
    dict[int, int] | None,
]:
    """
    Receive the elastic-EP state needed to initialise EPLB on this rank.

    The per-rank physical expert count is broadcast from rank 0 of the EP
    CPU group, then the expert loads and old placements are fetched via
    `EplbState.recv_state()`.

    Side effect: `parallel_config.eplb_config.num_redundant_experts` is
    overwritten with the value implied by the new EP world size.

    Returns:
        `(global_expert_loads, old_global_expert_indices_per_model,
        rank_mapping)`, where `rank_mapping` maps every old EP rank to
        itself.
    """
    local_experts_buf = torch.empty(1, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(
        local_experts_buf,
        group=get_ep_group().cpu_group,
        group_src=0,
    )
    local_expert_count = int(local_experts_buf.item())
    new_ep_size = get_ep_group().world_size

    loads, old_placements = EplbState.recv_state()

    # EP configuration for all models has to be the same so as eplb config,
    # hence the first model is representative.
    logical_expert_count = loads[0].shape[1]
    parallel_config.eplb_config.num_redundant_experts = (
        local_expert_count * new_ep_size - logical_expert_count
    )

    # The old placement must split evenly into per-rank chunks.
    old_physical_count = old_placements[0].shape[1]
    assert old_physical_count % local_expert_count == 0
    old_ep_size = old_physical_count // local_expert_count

    identity_mapping = {ep_rank: ep_rank for ep_rank in range(old_ep_size)}
    return loads, old_placements, identity_mapping
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
""" """
All-reduce a list of tensors. All-reduce a list of tensors.
...@@ -1203,6 +1033,60 @@ class EplbState: ...@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone()) load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list) return self._allreduce_list(load_pass_list)
@classmethod
def from_mapping(
    cls,
    model: MixtureOfExperts,
    model_config: ModelConfig,
    device: torch.device,
    parallel_config: ParallelConfig,
    expanded_physical_to_logical: torch.Tensor,
    num_valid_physical_experts: int,
) -> "EplbState":
    """
    Create an `EplbState` holding one model whose expert placement is
    taken from an existing physical-to-logical mapping.

    The logical-to-physical map and per-expert replica counts are derived
    from `expanded_physical_to_logical`; negative entries are treated as
    unused physical slots and skipped.

    Args:
        model: The MoE model to register.
        model_config: Config whose hash keys the model state.
        device: Device the derived maps are moved to before being copied
            into the model state.
        parallel_config: Passed through to the `EplbState` constructor.
        expanded_physical_to_logical: Indexed as `[layer, physical_slot]`,
            giving the logical expert in each slot.
        num_valid_physical_experts: Stored on the returned state.

    Returns:
        The populated `EplbState`.
    """
    state = cls(parallel_config=parallel_config, device=device)
    state.add_model(model=model, model_config=model_config)
    # Assigned after add_model() so that the caller-supplied value is the
    # one that ends up on the state.
    state.num_valid_physical_experts = num_valid_physical_experts

    layer_count = expanded_physical_to_logical.shape[0]
    slot_count = expanded_physical_to_logical.shape[1]

    model_state = state.model_states[model_config.compute_hash()]
    model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)

    # Reuse the replica capacity (last dim) of the map add_model() allocated.
    replica_capacity = model_state.logical_to_physical_map.shape[2]
    l2p_map = torch.full(
        (layer_count, model.num_logical_experts, replica_capacity),
        -1,
        dtype=torch.int64,
    )
    replica_counts = torch.zeros(
        (layer_count, model.num_logical_experts),
        dtype=torch.int64,
    )

    p2l_host = expanded_physical_to_logical.cpu().numpy()
    for layer in range(layer_count):
        for slot in range(slot_count):
            logical = p2l_host[layer, slot]
            if logical >= 0:
                # The running count doubles as the next free replica index.
                next_replica = replica_counts[layer, logical]
                l2p_map[layer, logical, next_replica] = slot
                replica_counts[layer, logical] += 1

    l2p_map = l2p_map.to(device)
    replica_counts = replica_counts.to(device)
    model_state.logical_to_physical_map.copy_(l2p_map)
    model_state.logical_replica_count.copy_(replica_counts)
    return state
@dataclass @dataclass
class EplbLayerState: class EplbLayerState:
......
...@@ -19,6 +19,8 @@ from torch.distributed import ( ...@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank, get_global_rank,
) )
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -249,10 +251,18 @@ def move_to_buffer( ...@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True) b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = [] p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping # Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size() ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends # 2. Post sends
if send_count > 0: if send_count > 0:
...@@ -284,15 +294,23 @@ def move_to_buffer( ...@@ -284,15 +294,23 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv): if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos]) recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks: for dst in recv_ranks:
dst_global = rank_to_global[dst] if is_stateless:
p2p_ops += [ for w in expert_weights:
P2POp( op = object.__new__(P2POp)
torch.distributed.isend, op.op = torch.distributed.isend
w[src], op.tensor = w[src]
dst_global, op.group_peer = dst
) p2p_ops.append(op)
for w in expert_weights else:
] dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs # 3. Post recvs
if recv_count > 0: if recv_count > 0:
...@@ -321,26 +339,40 @@ def move_to_buffer( ...@@ -321,26 +339,40 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender] src = ranks_to_send[recver_pos // num_dst_per_sender]
else: else:
src = ranks_to_send[recver_pos - remainder_start] src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src] if is_stateless:
p2p_ops += [ for b in expert_weights_buffers:
P2POp( op = object.__new__(P2POp)
torch.distributed.irecv, op.op = torch.distributed.irecv
b[dst], op.tensor = b[dst]
src_global, op.group_peer = src
) p2p_ops.append(op)
for b in expert_weights_buffers else:
] src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here. # 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None: if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream): with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops) reqs = batch_isend_irecv(p2p_ops)
for req in reqs: for req in reqs:
req.wait() req.wait()
elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# wait for the communication to finish # wait for the communication to finish
return ( return (
is_unchanged, is_unchanged,
......
...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext ...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -55,6 +55,9 @@ from vllm.utils.torch_utils import ( ...@@ -55,6 +55,9 @@ from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
) )
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
...@@ -1157,6 +1160,55 @@ def init_model_parallel_group( ...@@ -1157,6 +1160,55 @@ def init_model_parallel_group(
) )
def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    group_ports: list[list[int]],
    host: str,
    backend: str,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Build a StatelessGroupCoordinator for the given rank groups.

    The local rank, global rank and global world size are read from the
    current world group; all other settings come from the arguments.
    """
    # Imported lazily; at module level the name is only imported under
    # TYPE_CHECKING.
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    world_group = get_world_group()
    local_rank = world_group.local_rank
    global_rank = world_group.rank
    global_world_size = world_group.world_size

    return StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        group_name=group_name,
        host=host,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Tear down the active DP/EP/WORLD/EPLB groups and install new ones.

    Every currently set group is destroyed (in the order DP, EP, WORLD,
    EPLB) before the module-level references and the node count are
    overwritten with the arguments.

    Destruction is collective: all ranks of the old groups must call this
    function together. Passing ``None`` for every argument tears the
    groups down without installing replacements.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT

    active_groups = [g for g in (_DP, _EP, _WORLD, _EPLB) if g is not None]
    for active in active_groups:
        active.destroy()

    _WORLD, _DP, _EP, _EPLB = world, dp, ep, eplb
    _NODE_COUNT = node_count
_TP: GroupCoordinator | None = None _TP: GroupCoordinator | None = None
...@@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool): ...@@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable _ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    """Create the stateless world group used when elastic EP is enabled.

    The global rank is computed as
    ``data_parallel_rank * world_size + rank``, and the group covers
    ``world_size_across_dp`` ranks. The coordinator is created without a
    device communicator, using ``data_parallel_master_ip`` as host. On
    success the module-level ``_WORLD`` and ``_NODE_COUNT`` are set.

    Raises:
        AssertionError: if the world group already exists, or if
            ``nnodes_within_dp`` is not 1.
    """
    global _WORLD, _NODE_COUNT

    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    assert _WORLD is None, "world group already initialized"

    parallel_config = config.parallel_config
    global_rank = parallel_config.data_parallel_rank * world_size + rank
    global_world_size = parallel_config.world_size_across_dp

    all_ranks = list(range(global_world_size))
    if global_rank in all_ranks:
        # This rank belongs to the world: a single group with every rank.
        group_ranks = [all_ranks]
    else:
        # Otherwise fall back to one singleton group per rank.
        group_ranks = [[r] for r in all_ranks]
    group_ports = [parallel_config.get_next_stateless_world_group_port()]

    world = StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=parallel_config.data_parallel_master_ip,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )

    assert parallel_config.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    # Node count is derived from the TCP-store group of the new world.
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world
def init_distributed_environment( def init_distributed_environment(
world_size: int = -1, world_size: int = -1,
rank: int = -1, rank: int = -1,
...@@ -1273,6 +1358,7 @@ def init_distributed_environment( ...@@ -1273,6 +1358,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none() config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if ( if (
config is not None config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher" and config.parallel_config.distributed_executor_backend != "external_launcher"
...@@ -1280,6 +1366,7 @@ def init_distributed_environment( ...@@ -1280,6 +1366,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1 config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1 or config.parallel_config.data_parallel_size > 1
) )
and not enable_elastic_ep
): ):
parallel_config = config.parallel_config parallel_config = config.parallel_config
# adjust to take into account data parallelism # adjust to take into account data parallelism
...@@ -1333,6 +1420,18 @@ def init_distributed_environment( ...@@ -1333,6 +1420,18 @@ def init_distributed_environment(
rank=rank, rank=rank,
timeout=timeout, timeout=timeout,
) )
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816
...@@ -1341,6 +1440,9 @@ def init_distributed_environment( ...@@ -1341,6 +1440,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank # setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None: if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size())) ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend) _WORLD = init_world_group(ranks, local_rank, backend)
...@@ -1404,16 +1506,33 @@ def initialize_model_parallel( ...@@ -1404,16 +1506,33 @@ def initialize_model_parallel(
""" """
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none() from vllm.config import get_current_vllm_config
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP # the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model, # ExternalDP is the data parallel group that is not part of the model,
...@@ -1437,7 +1556,9 @@ def initialize_model_parallel( ...@@ -1437,7 +1556,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized" assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group # message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group( _TP = init_model_parallel_group(
group_ranks, group_ranks,
...@@ -1456,6 +1577,11 @@ def initialize_model_parallel( ...@@ -1456,6 +1577,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups. # TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group( _DCP = init_model_parallel_group(
group_ranks, group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
...@@ -1472,6 +1598,13 @@ def initialize_model_parallel( ...@@ -1472,6 +1598,13 @@ def initialize_model_parallel(
.unbind(0) .unbind(0)
) )
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group( _PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp" group_ranks, get_world_group().local_rank, backend, group_name="pcp"
) )
...@@ -1483,6 +1616,13 @@ def initialize_model_parallel( ...@@ -1483,6 +1616,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
) )
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group( _PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp" group_ranks, get_world_group().local_rank, backend, group_name="pp"
) )
...@@ -1491,14 +1631,27 @@ def initialize_model_parallel( ...@@ -1491,14 +1631,27 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized" assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group( if enable_elastic_ep:
group_ranks, get_world_group().local_rank, backend, group_name="dp" parallel_config = config.parallel_config
) dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
global _EP global _EP
assert _EP is None, "expert parallel group is already initialized" assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models. # Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe: if config.model_config is None or config.model_config.is_moe:
group_ranks = ( group_ranks = (
all_ranks.transpose(1, 2) all_ranks.transpose(1, 2)
.reshape( .reshape(
...@@ -1510,9 +1663,22 @@ def initialize_model_parallel( ...@@ -1510,9 +1663,22 @@ def initialize_model_parallel(
.unbind(0) .unbind(0)
) )
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group( if enable_elastic_ep:
group_ranks, get_world_group().local_rank, backend, group_name="ep" parallel_config = config.parallel_config
) ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
# Create EPLB group with the same ranks as EP if EPLB is enabled. # Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications # This is a separate process group to isolate EPLB communications
...@@ -1525,10 +1691,25 @@ def initialize_model_parallel( ...@@ -1525,10 +1691,25 @@ def initialize_model_parallel(
and config.parallel_config is not None and config.parallel_config is not None
and config.parallel_config.enable_eplb and config.parallel_config.enable_eplb
): ):
# Reuse the same group_ranks from EP if enable_elastic_ep:
_EPLB = init_model_parallel_group( eplb_ports = [
group_ranks, get_world_group().local_rank, backend, group_name="eplb" parallel_config.get_next_stateless_eplb_group_port()
) for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None # If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None # If no EPLB group needed, _EPLB remains None
...@@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized( ...@@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized. values if the model parallel groups are initialized.
""" """
backend = backend or torch.distributed.get_backend(get_world_group().device_group) world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel( initialize_model_parallel(
tensor_model_parallel_size, tensor_model_parallel_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
_get_unique_name,
_register_group,
_split_tensor_dict,
)
from vllm.distributed.utils import (
StatelessProcessGroup,
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class StatelessGroupCoordinator(GroupCoordinator):
    """
    A stateless version of the GroupCoordinator class in parallel_state.

    It creates CPU, device and TCPStore based communication groups
    that are independent of PyTorch's WORLD group. Hence,
    communication groups with a different set of participant GPUs
    can be created without destroying the existing ones.
    """

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
        host: str = "127.0.0.1",
        group_ports: list[list[int]] | None = None,
        global_rank: int = 0,
        global_world_size: int = 1,
    ):
        """Create the device, CPU and TCPStore groups this rank belongs to.

        Args:
            group_ranks: Member ranks of every group. This rank joins each
                group whose list contains ``global_rank``.
            local_rank: Device index used to build ``self.device``.
            torch_distributed_backend: Backend of the device process group.
            use_device_communicator: Whether to build a ``CudaCommunicator``
                (only done when the group has more than one member).
            use_message_queue_broadcaster: Accepted but not used here;
                ``self.mq_broadcaster`` is always ``None``.
            group_name: Base of the unique group name ("anonymous" if unset).
            host: Address handed to every group constructor.
            group_ports: For each entry of ``group_ranks``, the ports
                ``[device_port, cpu_port, tcp_store_port]``.
            global_rank: Rank of this process, looked up in ``group_ranks``.
            global_world_size: Forwarded to the device communicator.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = global_rank
        self.local_rank = local_rank

        from vllm.platforms import current_platform

        backend = str(torch_distributed_backend)
        self.backend = backend

        assert group_ports is not None, "group_ports is not provided"

        self_device_group = None
        self_cpu_group = None
        self_tcp_store_group = None

        for idx, ranks in enumerate(group_ranks):
            if self.rank not in ranks:
                continue

            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)

            ports = group_ports[idx]
            device_port, cpu_port, tcp_store_port = ports[0], ports[1], ports[2]

            # The three groups are created in a fixed order (device, CPU,
            # TCPStore); every member of the group runs the same sequence.
            self_device_group = stateless_init_torch_distributed_process_group(
                host=host,
                port=device_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
                backend=backend,
                group_name=f"{self.unique_name}_device",
            )
            self_cpu_group = stateless_init_torch_distributed_process_group(
                host=host,
                port=cpu_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
                backend="gloo",
                group_name=f"{self.unique_name}_cpu",
            )
            self_tcp_store_group = StatelessProcessGroup.create(
                host=host,
                port=tcp_store_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
            )

        # These fire when ``global_rank`` is not listed in any group.
        missing_msg = f"rank {self.rank} does not belong to any group in group_ranks"
        assert self_cpu_group is not None, missing_msg
        assert self_device_group is not None, missing_msg
        assert self_tcp_store_group is not None, missing_msg

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group
        self.tcp_store_group = self_tcp_store_group

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            assert device_comm_cls == CudaCommunicator, (
                "StatelessGroupCoordinator only supports CudaCommunicator, "
                f"got {device_comm_cls}"
            )
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
                global_ranks=self.ranks,
                global_world_size=global_world_size,
                tcp_store_group=self.tcp_store_group,
            )

        self.mq_broadcaster = None

        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )
        self.use_cpu_custom_send_recv = False

    def destroy(self):
        """Destroy the device communicator and the device/CPU process groups."""
        if self.device_communicator:
            self.device_communicator.destroy()
        if self.device_group:
            stateless_destroy_torch_distributed_process_group(self.device_group)
        if self.cpu_group:
            stateless_destroy_torch_distributed_process_group(self.cpu_group)

    def size(self) -> int:
        """Return the world size of this group."""
        return self.world_size

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast ``input_`` from group rank ``src`` and return the result.

        CUDA tensors go through the device communicator when one exists;
        everything else goes through the TCPStore group. A single-member
        group returns ``input_`` unchanged.
        """
        if self.world_size == 1:
            return input_

        if self.device_communicator and input_.is_cuda:
            return self.device_communicator.broadcast(input_, src)
        return self.tcp_store_group.broadcast(input_, src)

    def broadcast_object(self, obj=None, src: int = 0):
        """Broadcast a Python object from group rank ``src`` via the TCPStore group."""
        if self.world_size == 1:
            return obj
        return self.tcp_store_group.broadcast_obj(obj, src)

    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
    ):
        """Broadcast every element of ``obj_list`` from group rank ``src``.

        Non-source ranks have their list overwritten in place, one broadcast
        per element, so all ranks must pass lists of the same length.
        ``group`` is accepted but not used by this implementation.
        """
        assert src < self.world_size

        if self.world_size == 1:
            return obj_list

        if self.rank_in_group == src:
            for obj in obj_list:
                self.tcp_store_group.broadcast_obj(obj, src)
        else:
            # Slice assignment keeps the caller's list object and fills it
            # with one received element per original slot.
            obj_list[:] = [
                self.tcp_store_group.broadcast_obj(None, src) for _ in obj_list
            ]

        return obj_list

    def broadcast_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
        src: int = 0,
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Broadcast a dict of tensors and plain values from group rank ``src``.

        The metadata list is broadcast first through the TCPStore group, then
        every non-empty tensor is broadcast individually. ``group`` and
        ``metadata_group`` are accepted but not used by this implementation.
        """
        if self.world_size == 1:
            return tensor_dict

        is_src = self.rank_in_group == src
        if is_src:
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        else:
            metadata_list = None
            tensor_list = []

        recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
            metadata_list, src
        )

        if not is_src:
            # Rebuild the dict, allocating a receive buffer for each tensor
            # described by the metadata.
            tensor_dict = {}
            for key, value in recv_metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
                    tensor_list.append(tensor)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value

        for tensor in tensor_list:
            # Empty tensors carry no payload, so no broadcast is issued.
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                tensor.copy_(self.device_communicator.broadcast(tensor, src))
            else:
                tensor.copy_(self.tcp_store_group.broadcast(tensor, src))

        return tensor_dict

    def send_object(self, obj, dst: int) -> None:
        """Send a Python object to group rank ``dst`` (must not be this rank)."""
        assert dst < self.world_size
        assert dst != self.rank_in_group
        self.tcp_store_group.send_obj(obj, dst)

    def recv_object(self, src: int):
        """Receive a Python object from group rank ``src`` (must not be this rank)."""
        assert src < self.world_size
        assert src != self.rank_in_group
        return self.tcp_store_group.recv_obj(src)

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Send a dict of tensors and plain values to group rank ``dst``.

        ``dst`` defaults to the next rank in the group (wrapping around).
        A single-member group returns ``tensor_dict`` unchanged; otherwise
        ``None`` is returned. ``all_gather_group`` and ``all_gather_tensors``
        are accepted but not used by this implementation.
        """
        if self.world_size == 1:
            return tensor_dict

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.tcp_store_group.send_obj(metadata_list, dst)

        for tensor in tensor_list:
            # Empty tensors are described by metadata only.
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                self.device_communicator.send(tensor, dst)
            else:
                self.tcp_store_group.send(tensor, dst)

        return None

    def recv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Receive a dict of tensors and plain values from group rank ``src``.

        ``src`` defaults to the previous rank in the group (wrapping around).
        A single-member group returns ``None``. ``all_gather_group`` and
        ``all_gather_tensors`` are accepted but not used by this
        implementation.
        """
        if self.world_size == 1:
            return None

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size

        recv_metadata_list = self.tcp_store_group.recv_obj(src)
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if not isinstance(value, TensorMetadata):
                tensor_dict[key] = value
                continue

            tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
            # Only non-empty tensors were actually transmitted by the sender.
            if tensor.numel() > 0:
                if self.device_communicator and tensor.is_cuda:
                    tensor = self.device_communicator.recv(
                        tensor.size(), tensor.dtype, src
                    )
                else:
                    tensor = self.tcp_store_group.recv(tensor, src)
            tensor_dict[key] = tensor

        return tensor_dict

    def barrier(self):
        """Synchronize all ranks of the group through the TCPStore group."""
        self.tcp_store_group.barrier()

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """Gather ``input_`` from every rank onto group rank ``dst``.

        Returns the per-rank tensors concatenated along ``dim`` on ``dst``
        and ``None`` on every other rank. A single-member group returns
        ``input_`` unchanged.

        Raises:
            ValueError: If the group has no device communicator.
        """
        if self.world_size == 1:
            return input_

        if self.device_communicator is None:
            raise ValueError("No device communicator found")

        if self.rank_in_group != dst:
            self.device_communicator.send(input_, dst)
            return None

        # Receive in ascending rank order and slot our own tensor in place;
        # no placeholder buffers are allocated since every entry is either
        # ``input_`` itself or a freshly received tensor.
        gathered_list = [
            input_
            if src_rank == self.rank_in_group
            else self.device_communicator.recv(input_.size(), input_.dtype, src_rank)
            for src_rank in range(self.world_size)
        ]
        return torch.cat(gathered_list, dim=dim)
...@@ -18,7 +18,7 @@ from datetime import timedelta ...@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any from typing import Any
import torch import torch
from torch.distributed import ProcessGroup, TCPStore from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import ( from torch.distributed.distributed_c10d import (
Backend, Backend,
PrefixStore, PrefixStore,
...@@ -228,6 +228,55 @@ class StatelessProcessGroup: ...@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj) gathered_objs.append(recv_obj)
return gathered_objs return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Broadcast a tensor from source rank to all other ranks.

    The source rank pickles ``tensor`` into the store under a key built from
    ``src`` and its send counter, then returns ``tensor`` itself. Every other
    rank reads the key matching its own per-source receive counter and
    returns the unpickled tensor; the ``tensor`` argument is ignored there.
    """
    if self.rank != src:
        recv_key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
        payload = self.store.get(recv_key)
        received = pickle.loads(payload)
        self.broadcast_recv_src_counter[src] += 1
        return received

    # Source rank: serialize first, then publish under the next send slot.
    payload = pickle.dumps(tensor)
    self.expire_data()
    send_key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
    self.store.set(send_key, payload)
    self.broadcast_send_counter += 1
    # Record the key with its publish time alongside the other entries.
    self.entries.append((send_key, time.time()))
    return tensor
def send(self, tensor: torch.Tensor, dst: int):
    """Send a tensor to a destination rank.

    The pickled tensor is stored under a key made of the sender rank, the
    destination rank and the sender's per-destination counter. It is read
    back by ``recv`` on the destination rank.
    """
    self.expire_data()
    # The sender rank is part of the key: without it, two different senders
    # targeting the same destination with equal counters would write to the
    # same key and overwrite each other's payload.
    key = f"send_tensor/{self.rank}/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(tensor))
    self.send_dst_counter[dst] += 1
    self.entries.append((key, time.time()))


def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Receive a tensor from a source rank.

    Reads the payload that ``send`` on rank ``src`` stored for this rank,
    copies it into ``tensor`` in place and returns ``tensor``.
    """
    # Mirror of the key built in ``send``: <sender>/<receiver>/<sequence>.
    key = f"send_tensor/{src}/{self.rank}/{self.recv_src_counter[src]}"
    received = pickle.loads(self.store.get(key))
    self.recv_src_counter[src] += 1
    tensor.copy_(received)
    return tensor
def all_reduce(
    self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
    """All-reduce a tensor across all ranks.

    Every rank gathers all tensors with ``all_gather_obj`` and reduces them
    locally, starting from a clone of the first gathered tensor, so the
    input tensor is not modified.

    Args:
        tensor: This rank's contribution.
        op: One of ``ReduceOp.SUM``, ``PRODUCT``, ``MAX`` or ``MIN``.

    Returns:
        A new tensor holding the reduction over all ranks.

    Raises:
        ValueError: If ``op`` is not one of the supported reductions.
    """
    reduce_op = torch.distributed.ReduceOp
    # Resolve the reduction once, before any communication, so that an
    # unsupported op fails loudly instead of silently returning the first
    # rank's tensor unreduced.
    if op == reduce_op.SUM:
        combine = torch.Tensor.add_
    elif op == reduce_op.PRODUCT:
        combine = torch.Tensor.mul_
    elif op == reduce_op.MAX:
        combine = torch.maximum
    elif op == reduce_op.MIN:
        combine = torch.minimum
    else:
        raise ValueError(
            f"Unsupported reduce op for StatelessProcessGroup.all_reduce: {op}"
        )

    tensors = self.all_gather_obj(tensor)
    result = tensors[0].clone()
    for t in tensors[1:]:
        # add_/mul_ update ``result`` in place and return it; maximum/minimum
        # return a new tensor. Rebinding covers both cases.
        result = combine(result, t)
    return result
def barrier(self, timeout: float = 30.0): def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks. """A robust barrier to synchronize all ranks.
...@@ -448,8 +497,14 @@ def init_gloo_process_group( ...@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group( def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str host: str,
) -> ProcessGroup: port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
""" """
A replacement for `torch.distributed.init_process_group` that does not A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for pollute the global state. The created ProcessGroup object can be used for
...@@ -496,26 +551,36 @@ def stateless_init_torch_distributed_process_group( ...@@ -496,26 +551,36 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by # Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant. # different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store) prefix_store = PrefixStore(init_method, store)
try:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg( if backend == "gloo":
backend=backend, pg = init_gloo_process_group(
prefix_store=prefix_store, prefix_store=prefix_store,
group_rank=group_rank, group_rank=group_rank,
group_size=group_size, group_size=group_size,
timeout=timeout, timeout=timeout,
) )
except NotImplementedError: else:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it from vllm.platforms import current_platform
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group( pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store, prefix_store=prefix_store,
group_rank=group_rank, group_rank=group_rank,
group_size=group_size, group_size=group_size,
timeout=timeout, timeout=timeout,
) )
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
""" """
......
...@@ -419,6 +419,7 @@ class EngineArgs: ...@@ -419,6 +419,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
enable_dbo: bool = ParallelConfig.enable_dbo enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
...@@ -896,6 +897,9 @@ class EngineArgs: ...@@ -896,6 +897,9 @@ class EngineArgs:
"--ubatch-size", "--ubatch-size",
**parallel_kwargs["ubatch_size"], **parallel_kwargs["ubatch_size"],
) )
parallel_group.add_argument(
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
)
parallel_group.add_argument( parallel_group.add_argument(
"--dbo-decode-token-threshold", "--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"], **parallel_kwargs["dbo_decode_token_threshold"],
...@@ -1698,6 +1702,7 @@ class EngineArgs: ...@@ -1698,6 +1702,7 @@ class EngineArgs:
is_moe_model=model_config.is_moe, is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend, all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo, enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size, ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_decode_token_threshold=self.dbo_decode_token_threshold,
......
...@@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None api_server_manager: APIServerProcessManager | None = None
from vllm.v1.engine.utils import get_engine_zmq_addresses
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
with launch_core_engines( with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses): ) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front. # Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict( api_server_manager_kwargs = dict(
......
...@@ -243,6 +243,8 @@ if TYPE_CHECKING: ...@@ -243,6 +243,8 @@ if TYPE_CHECKING:
VLLM_LORA_DISABLE_PDL: bool = False VLLM_LORA_DISABLE_PDL: bool = False
VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get( "VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
"VLLM_CUDA_COMPATIBILITY_PATH", None "VLLM_CUDA_COMPATIBILITY_PATH", None
), ),
# Whether this engine is being launched as part of an elastic EP scale-up.
# Should only be set by EngineCoreClient.
"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0"))
),
# Whether to wait for all requests to drain before sending the
# scaling command in elastic EP.
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
),
} }
......
...@@ -627,6 +627,7 @@ class FusedMoE(CustomOp): ...@@ -627,6 +627,7 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
self.base_quant_method = self.quant_method
# Disable shared expert overlap if: # Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues # - we are using eplb with non-default backend, because of correctness issues
...@@ -683,7 +684,7 @@ class FusedMoE(CustomOp): ...@@ -683,7 +684,7 @@ class FusedMoE(CustomOp):
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend. # DeepEP all2all backend.
routing_tables = self._maybe_init_expert_routing_tables() routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize( prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables routing_tables=routing_tables
) )
if prepare_finalize is not None: if prepare_finalize is not None:
...@@ -693,7 +694,7 @@ class FusedMoE(CustomOp): ...@@ -693,7 +694,7 @@ class FusedMoE(CustomOp):
self._replace_quant_method( self._replace_quant_method(
FusedMoEModularMethod.make( FusedMoEModularMethod.make(
self, self,
self.quant_method, self.base_quant_method,
prepare_finalize, prepare_finalize,
self.shared_experts, self.shared_experts,
inplace=not self.moe_config.disable_inplace, inplace=not self.moe_config.disable_inplace,
......
...@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context. ...@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context.
import os import os
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
...@@ -482,6 +485,37 @@ class CudaPlatformBase(Platform): ...@@ -482,6 +485,37 @@ class CudaPlatformBase(Platform):
def get_static_graph_wrapper_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper" return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Build a ProcessGroup backed by NCCL for CUDA devices.

    A bare ``ProcessGroup`` is created on ``prefix_store`` and an NCCL
    backend with the given ``timeout`` is registered on it for the "cuda"
    device. The ``backend`` argument is not consulted; NCCL is always used.
    """
    assert is_nccl_available()

    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    process_group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    nccl_options = ProcessGroupNCCL.Options()
    nccl_options._timeout = timeout
    nccl_backend = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, nccl_options
    )

    # Same call order as before: default backend, sequence number, register.
    process_group._set_default_backend(ProcessGroup.BackendType.NCCL)
    nccl_backend._set_sequence_number_for_group()
    process_group._register_backend(
        torch.device("cuda"), ProcessGroup.BackendType.NCCL, nccl_backend
    )
    return process_group
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
......
...@@ -2,10 +2,13 @@ ...@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from datetime import timedelta
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -656,6 +659,37 @@ class RocmPlatform(Platform): ...@@ -656,6 +659,37 @@ class RocmPlatform(Platform):
def get_static_graph_wrapper_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper" return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Build a ProcessGroup backed by ``ProcessGroupNCCL`` for this platform.

    A bare ``ProcessGroup`` is created on ``prefix_store`` and a
    ``ProcessGroupNCCL`` backend with the given ``timeout`` is registered on
    it for the "cuda" device. The ``backend`` argument is not consulted.
    """
    assert is_nccl_available()

    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    process_group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    nccl_options = ProcessGroupNCCL.Options()
    nccl_options._timeout = timeout
    nccl_backend = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, nccl_options
    )

    # Same call order as before: default backend, sequence number, register.
    process_group._set_default_backend(ProcessGroup.BackendType.NCCL)
    nccl_backend._set_sequence_number_for_group()
    process_group._register_backend(
        torch.device("cuda"), ProcessGroup.BackendType.NCCL, nccl_backend
    )
    return process_group
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return cuda_device_count_stateless()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment