Unverified commit dea26833, authored by Itay Alroy, committed by GitHub
Browse files

[1/N] Elastic EP Milestone 2 (#34861)


Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
parent 90805ff4
......@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
......@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
......@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
......@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
    """Destroy every cached all2all handle and empty the handle cache.

    The cache's lock is held for the whole teardown so the set of handles
    cannot change while they are being destroyed.
    """
    cache = self.handle_cache
    with cache._lock:
        # Only the handles matter here; the cache keys are not needed.
        for cached_handle in cache._cache.values():
            cached_handle.destroy()
        cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
......@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
......@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
......@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
......@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
......@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,
......
......@@ -29,8 +29,9 @@ class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
......@@ -47,12 +48,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs):
# get a handle for the all2all communication,
......@@ -121,11 +127,30 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
......@@ -145,7 +170,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
......@@ -275,6 +300,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` from group rank ``src`` to every rank in the group.

    Args:
        tensor: Tensor handed to ``torch.distributed.broadcast``.
        src: Source rank expressed as an index into ``self.ranks``.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    # A single-rank group has nobody to talk to, so skip the collective.
    if self.world_size > 1:
        # ``self.ranks[src]`` converts the group-local index into the rank
        # value expected by ``torch.distributed.broadcast``.
        torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
    return tensor
def destroy(self):
pass
......@@ -343,3 +375,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
    """Execute a batch of point-to-point operations.

    The base communicator provides no implementation; subclasses that
    support batched send/recv override this method.

    Args:
        p2p_ops: The point-to-point operations to execute.

    Raises:
        NotImplementedError: Always, in the base class. The message names
            the concrete communicator class so the failure is traceable.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement batch_isend_irecv"
    )
......@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
......@@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
......@@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
......@@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
......@@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
......@@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` from rank ``src`` using the PyNCCL communicator.

    Args:
        tensor: Tensor passed to ``pynccl_comm.broadcast``.
        src: Source rank, forwarded unchanged to the PyNCCL communicator.

    Returns:
        The same ``tensor`` object that was passed in.

    Raises:
        ValueError: If the group has more than one rank but there is no
            PyNCCL communicator, or it is disabled.
    """
    # Nothing to exchange in a single-rank group.
    if self.world_size == 1:
        return tensor

    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.broadcast(tensor, src)
    return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
......@@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
    """Run a batch of point-to-point operations through PyNCCL.

    Args:
        p2p_ops: Operations forwarded as-is to
            ``pynccl_comm.batch_isend_irecv``.

    Raises:
        ValueError: If there is no PyNCCL communicator or it is disabled.
    """
    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")
    comm.batch_isend_irecv(p2p_ops)
......@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
......@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
......@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
    """Issue a batch of sends and receives as one grouped operation.

    Every op is dispatched between ``self.group_start()`` and
    ``self.group_end()``.

    Args:
        p2p_ops: Objects exposing ``op`` (either ``torch.distributed.isend``
            or ``torch.distributed.irecv``), ``tensor`` and ``group_peer``
            (the peer rank handed to ``send``/``recv``).
        stream: Stream passed to ``send``/``recv``; defaults to
            ``current_stream()`` when ``None``.

    Raises:
        ValueError: If an op is neither an isend nor an irecv. Silently
            dropping such an op would leave the peer waiting on a transfer
            that is never posted, so it is rejected before the group opens.
    """
    if self.disabled:
        return
    if stream is None:
        stream = current_stream()

    isend = torch.distributed.isend
    irecv = torch.distributed.irecv

    # Validate up front so a bad op never leaves a half-filled group behind.
    for op in p2p_ops:
        if op.op is not isend and op.op is not irecv:
            raise ValueError(
                f"Unsupported P2P op in batch_isend_irecv: {op.op!r}"
            )

    self.group_start()
    try:
        for op in p2p_ops:
            if op.op is isend:
                self.send(op.tensor, op.group_peer, stream)
            else:
                self.recv(op.tensor, op.group_peer, stream)
    finally:
        # Always close the group, even when a send/recv raises, so the
        # communicator is not left inside an open group.
        self.group_end()
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
    """Return the staged standby DP group, or ``None`` if none is staged."""
    standby_dp = _STANDBY_DP
    return standby_dp
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
    """Return the staged standby EP group, or ``None`` if none is staged."""
    standby_ep = _STANDBY_EP
    return standby_ep
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
    """Return the staged standby EPLB group, or ``None`` if none is staged."""
    standby_eplb = _STANDBY_EPLB
    return standby_eplb
def get_standby_world_group() -> StatelessGroupCoordinator | None:
    """Return the staged standby world group, or ``None`` if none is staged."""
    standby_world = _STANDBY_WORLD
    return standby_world
def create_standby_groups(
    new_dp_size: int,
    new_world_size_across_dp: int,
    master_ip: str,
    world_group_ports: list[list[int]],
    dp_group_ports: list[list[int]],
    ep_group_ports: list[list[int]],
    eplb_group_ports: list[list[int]] | None = None,
    backend: str | None = None,
) -> None:
    """Build standby WORLD/DP/EP (and optionally EPLB) groups for a new DP size.

    The groups are stored in the module-level ``_STANDBY_*`` variables; the
    currently active groups are not modified here.

    Args:
        new_dp_size: Data-parallel size of the standby layout.
        new_world_size_across_dp: Total rank count of the standby layout;
            must equal ``torch.distributed.get_world_size() * new_dp_size``.
        master_ip: Host passed to every stateless group.
        world_group_ports: Ports for the standby world group.
        dp_group_ports: Ports for the standby DP groups.
        ep_group_ports: Ports for the standby EP groups.
        eplb_group_ports: Ports for the standby EPLB groups; when ``None``
            no EPLB group is created.
        backend: Backend name; falls back to the active world group's
            backend when falsy.
    """
    global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT
    global _STANDBY_DP, _STANDBY_EP, _STANDBY_EPLB

    assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
    active_world = get_world_group()
    assert isinstance(active_world, StatelessGroupCoordinator)
    if not backend:
        backend = active_world.backend

    # One world group spanning every rank of the standby layout. It is
    # created without a device communicator.
    _STANDBY_WORLD = _init_stateless_group(
        [list(range(new_world_size_across_dp))],
        "world",
        world_group_ports,
        master_ip,
        backend,
        use_device_communicator=False,
    )
    _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size
    # Rank grid laid out as (outer, dp, pp, tp).
    rank_grid = torch.arange(new_world_size_across_dp).reshape(
        -1, new_dp_size, pp_size, tp_size
    )

    # DP groups: move the dp axis last so each row varies only the dp index.
    dp_rank_groups = rank_grid.transpose(1, 3).reshape(-1, new_dp_size).tolist()
    _STANDBY_DP = _init_stateless_group(
        dp_rank_groups, "dp", dp_group_ports, master_ip, backend
    )

    # EP groups: each row holds all dp x tp ranks that share a pp index.
    ep_rank_groups = (
        rank_grid.transpose(1, 2).reshape(-1, new_dp_size * tp_size).tolist()
    )
    _STANDBY_EP = _init_stateless_group(
        ep_rank_groups, "ep", ep_group_ports, master_ip, backend
    )

    # The EPLB group reuses the EP rank layout on its own set of ports.
    if eplb_group_ports is not None:
        _STANDBY_EPLB = _init_stateless_group(
            ep_rank_groups, "eplb", eplb_group_ports, master_ip, backend
        )
def pop_standby_groups() -> dict:
    """Return all standby groups and clear the standby state.

    Returns:
        A dict with keys ``world``, ``dp``, ``ep``, ``eplb`` and
        ``node_count`` holding the values staged at the time of the call.
        After the call every module-level standby variable is ``None``.
    """
    global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT
    global _STANDBY_DP, _STANDBY_EP, _STANDBY_EPLB

    # Snapshot first, then reset, so the caller becomes the sole owner.
    snapshot = {
        "world": _STANDBY_WORLD,
        "dp": _STANDBY_DP,
        "ep": _STANDBY_EP,
        "eplb": _STANDBY_EPLB,
        "node_count": _STANDBY_WORLD_NODE_COUNT,
    }
    _STANDBY_WORLD = _STANDBY_DP = _STANDBY_EP = _STANDBY_EPLB = None
    _STANDBY_WORLD_NODE_COUNT = None
    return snapshot
......@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
......@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
......@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
......@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)
......
......@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count,
in_the_same_node_as,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
......@@ -302,6 +303,14 @@ class EplbState:
"""
CUDA device index for the async EPLB worker thread.
"""
self.num_valid_physical_experts: int = 0
"""
Number of valid physical experts.
This is the number of physical experts that are
actually mapped to logical experts. In elastic EP,
newly started EP ranks may not have physical experts
mapped yet.
"""
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
......@@ -367,9 +376,6 @@ class EplbState:
self,
model: MixtureOfExperts,
model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
):
"""
Build the initial EPLB state.
......@@ -462,75 +468,15 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
else:
new_physical_to_logical_map = None
new_logical_to_physical_map = None
new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
......@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]),
),
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
new_logical_to_physical_map=new_logical_to_physical_map,
new_logical_replica_count=new_logical_replica_count,
new_physical_to_logical_map=None,
new_logical_to_physical_map=None,
new_logical_replica_count=None,
)
self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts
def step(
self,
......@@ -696,8 +643,6 @@ class EplbState:
def rearrange(
self,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
......@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
......@@ -734,18 +673,12 @@ class EplbState:
"(profile)" if is_profile else "",
)
if global_expert_loads is None:
# Map the physical expert load to global logical experts
global_expert_load_windows = []
if not execute_shuffle:
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_models, group=get_ep_group().cpu_group, group_src=0
)
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
......@@ -755,46 +688,19 @@ class EplbState:
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(eplb_model_state.expert_load_window)
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
src=expert_load_window,
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
)
if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_windows = global_expert_loads
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
......@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
coordinator = get_ep_group()
assert isinstance(coordinator, StatelessGroupCoordinator)
tcp_store_group = coordinator.tcp_store_group
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
......@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None:
self.async_worker = start_async_worker(
self,
rank_mapping=rank_mapping,
is_profile=is_profile,
)
......@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_models = num_models.item()
global_expert_loads = []
old_global_expert_indices_per_model = []
for _ in range(num_models):
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(
old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
cls, parallel_config: ParallelConfig
) -> tuple[
list[torch.Tensor] | None,
list[torch.Tensor] | None,
dict[int, int] | None,
]:
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(
num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_loads, old_global_expert_indices_per_model = (
EplbState.recv_state()
)
# EP configuration for all models has to be the same so as eplb config
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
......@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
@classmethod
def from_mapping(
    cls,
    model: MixtureOfExperts,
    model_config: ModelConfig,
    device: torch.device,
    parallel_config: ParallelConfig,
    expanded_physical_to_logical: torch.Tensor,
    num_valid_physical_experts: int,
) -> "EplbState":
    """Build an ``EplbState`` for one model from a physical-to-logical map.

    The state is created, the model is registered via ``add_model``, and
    the model's ``physical_to_logical_map``, ``logical_to_physical_map``
    and ``logical_replica_count`` are then overwritten with values derived
    from ``expanded_physical_to_logical``.

    Args:
        model: The MoE model to register.
        model_config: Its config; ``compute_hash()`` keys the model state.
        device: Device for the state and the derived maps.
        parallel_config: Parallel configuration forwarded to the constructor.
        expanded_physical_to_logical: Map indexed as ``[layer, physical]``.
            Negative entries are skipped as unmapped slots.
        num_valid_physical_experts: Stored on the state after ``add_model``
            has run.

    Returns:
        The populated ``EplbState``.
    """
    state = cls(parallel_config=parallel_config, device=device)
    state.add_model(model=model, model_config=model_config)
    state.num_valid_physical_experts = num_valid_physical_experts

    num_moe_layers = expanded_physical_to_logical.shape[0]
    num_physical_experts = expanded_physical_to_logical.shape[1]
    num_logical_experts = model.num_logical_experts

    model_state = state.model_states[model_config.compute_hash()]
    model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)

    # Reuse the replica-slot width already allocated for this model.
    max_replica_slots = model_state.logical_to_physical_map.shape[2]
    inverse_map = torch.full(
        (num_moe_layers, num_logical_experts, max_replica_slots),
        -1,
        dtype=torch.int64,
    )
    replica_counts = torch.zeros(
        (num_moe_layers, num_logical_experts),
        dtype=torch.int64,
    )

    # Walk the mapping on the host to invert it slot by slot.
    host_mapping = expanded_physical_to_logical.cpu().numpy()
    for layer_idx in range(num_moe_layers):
        for phys_idx in range(num_physical_experts):
            logical_idx = host_mapping[layer_idx, phys_idx]
            if logical_idx < 0:
                continue
            # The next free replica slot is the count recorded so far.
            next_slot = replica_counts[layer_idx, logical_idx]
            inverse_map[layer_idx, logical_idx, next_slot] = phys_idx
            replica_counts[layer_idx, logical_idx] += 1

    model_state.logical_to_physical_map.copy_(inverse_map.to(device))
    model_state.logical_replica_count.copy_(replica_counts.to(device))
    return state
@dataclass
class EplbLayerState:
......
......@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank,
)
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping
# Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends
if send_count > 0:
......@@ -284,6 +294,14 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
if is_stateless:
for w in expert_weights:
op = object.__new__(P2POp)
op.op = torch.distributed.isend
op.tensor = w[src]
op.group_peer = dst
p2p_ops.append(op)
else:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
......@@ -321,6 +339,14 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
if is_stateless:
for b in expert_weights_buffers:
op = object.__new__(P2POp)
op.op = torch.distributed.irecv
op.tensor = b[dst]
op.group_peer = src
p2p_ops.append(op)
else:
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
......@@ -334,10 +360,16 @@ def move_to_buffer(
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
......
......@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
......@@ -55,6 +55,9 @@ from vllm.utils.torch_utils import (
direct_register_custom_op,
)
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
@dataclass
class GraphCaptureContext:
......@@ -1157,6 +1160,55 @@ def init_model_parallel_group(
)
def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    group_ports: list[list[int]],
    host: str,
    backend: str,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Create a StatelessGroupCoordinator with the given parameters.

    The local rank, global rank and global world size are taken from the
    current world group.
    """
    # Imported inside the function; at module level this name is only
    # imported under TYPE_CHECKING.
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    current_world = get_world_group()
    coordinator_kwargs = {
        "group_ranks": group_ranks,
        "local_rank": current_world.local_rank,
        "torch_distributed_backend": backend,
        "use_device_communicator": use_device_communicator,
        "group_name": group_name,
        "host": host,
        "group_ports": group_ports,
        "global_rank": current_world.rank,
        "global_world_size": current_world.world_size,
    }
    return StatelessGroupCoordinator(**coordinator_kwargs)
def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Destroy the current DP/EP/WORLD/EPLB groups and replace them.

    Destruction is collective — all ranks in the old groups must call this
    function together. Pass all-``None`` to tear down without replacement.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT

    # Tear down in the fixed order DP, EP, WORLD, EPLB, skipping unset groups.
    outgoing = [_DP, _EP, _WORLD, _EPLB]
    for old_group in outgoing:
        if old_group is None:
            continue
        old_group.destroy()

    _WORLD, _DP, _EP, _EPLB, _NODE_COUNT = world, dp, ep, eplb, node_count
_TP: GroupCoordinator | None = None
......@@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    """Create the stateless WORLD group for elastic EP and publish it.

    The global rank is ``data_parallel_rank * world_size + rank`` and the
    global world size is ``world_size_across_dp``. On success the module
    globals ``_WORLD`` and ``_NODE_COUNT`` are set.

    Raises:
        AssertionError: if the world group already exists, or if
            ``nnodes_within_dp`` is not 1.
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    global _WORLD, _NODE_COUNT
    assert _WORLD is None, "world group already initialized"

    parallel_config = config.parallel_config
    global_world_size = parallel_config.world_size_across_dp
    global_rank = parallel_config.data_parallel_rank * world_size + rank

    all_ranks = list(range(global_world_size))
    if global_rank in all_ranks:
        # A single group spanning every global rank.
        group_ranks = [all_ranks]
    else:
        # Otherwise keep one singleton group per rank.
        group_ranks = [[single_rank] for single_rank in all_ranks]
    group_ports = [parallel_config.get_next_stateless_world_group_port()]

    coordinator_kwargs = dict(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=parallel_config.data_parallel_master_ip,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
    world = StatelessGroupCoordinator(**coordinator_kwargs)

    assert parallel_config.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
......@@ -1273,6 +1358,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
......@@ -1280,6 +1366,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
......@@ -1333,6 +1420,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
......@@ -1341,6 +1440,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
......@@ -1404,16 +1506,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
from vllm.config import get_current_vllm_config
config = get_current_vllm_config_or_none()
if config is not None:
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
......@@ -1437,7 +1556,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
......@@ -1456,6 +1577,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
......@@ -1472,6 +1598,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
......@@ -1483,6 +1616,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
......@@ -1491,6 +1631,19 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
......@@ -1498,7 +1651,7 @@ def initialize_model_parallel(
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
......@@ -1510,6 +1663,19 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
......@@ -1525,9 +1691,24 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
......@@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
_get_unique_name,
_register_group,
_split_tensor_dict,
)
from vllm.distributed.utils import (
StatelessProcessGroup,
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class StatelessGroupCoordinator(GroupCoordinator):
    """
    A stateless version of the GroupCoordinator class in parallel_state.

    It creates CPU, device and TCPStore based communication groups that are
    independent of PyTorch's WORLD group. Hence, communication groups with a
    different set of participant GPUs can be created without destroying the
    existing ones.
    """

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
        host: str = "127.0.0.1",
        group_ports: list[list[int]] | None = None,
        global_rank: int = 0,
        global_world_size: int = 1,
    ):
        """Join the group in ``group_ranks`` that contains ``global_rank``.

        ``GroupCoordinator.__init__`` is not invoked; all attributes are set
        here directly.

        Args:
            group_ranks: Global ranks of every group. Groups are created only
                for entries containing ``global_rank``; if several match, the
                attributes end up describing the last one.
            local_rank: Device index used to build ``self.device``.
            torch_distributed_backend: Backend of the device process group.
            use_device_communicator: Whether to build a ``CudaCommunicator``
                (only done when the group has more than one member).
            use_message_queue_broadcaster: Accepted but not used here;
                ``mq_broadcaster`` is always ``None``.
            group_name: Base of the unique group name ("anonymous" if unset).
            host: Address passed to every group initializer.
            group_ports: One port list per entry of ``group_ranks``; indices
                0, 1 and 2 are the device, CPU (gloo) and TCP-store ports.
            global_rank: Rank of this process across all groups.
            global_world_size: Forwarded to the device communicator.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = global_rank
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None
        self_tcp_store_group = None

        from vllm.platforms import current_platform

        backend = str(torch_distributed_backend)
        self.backend = backend

        assert group_ports is not None, "group_ports is not provided"
        for idx, ranks in enumerate(group_ranks):
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)

                ports = group_ports[idx]
                device_port = ports[0]
                cpu_port = ports[1]
                tcp_store_port = ports[2]

                device_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=device_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend=backend,
                    group_name=f"{self.unique_name}_device",
                )
                cpu_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=cpu_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend="gloo",
                    group_name=f"{self.unique_name}_cpu",
                )
                tcp_store_group = StatelessProcessGroup.create(
                    host=host,
                    port=tcp_store_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                )

                self_device_group = device_group
                self_cpu_group = cpu_group
                self_tcp_store_group = tcp_store_group

        # These fail when global_rank appears in none of the groups.
        assert self_cpu_group is not None
        assert self_device_group is not None
        assert self_tcp_store_group is not None

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group
        self.tcp_store_group = self_tcp_store_group

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            # Only the CUDA communicator is supported by this coordinator.
            assert device_comm_cls == CudaCommunicator
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
                global_ranks=self.ranks,
                global_world_size=global_world_size,
                tcp_store_group=self.tcp_store_group,
            )

        self.mq_broadcaster = None

        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )
        self.use_cpu_custom_send_recv = False

    def destroy(self):
        """Destroy the device communicator and the device/CPU process groups."""
        if self.device_communicator:
            self.device_communicator.destroy()
        if self.device_group:
            stateless_destroy_torch_distributed_process_group(self.device_group)
        if self.cpu_group:
            stateless_destroy_torch_distributed_process_group(self.cpu_group)

    def size(self) -> int:
        """Return the world size of this group."""
        return self.world_size

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast ``input_`` from group rank ``src``.

        CUDA tensors go through the device communicator when one exists;
        everything else goes through the TCP-store group.
        """
        if self.world_size == 1:
            return input_

        if self.device_communicator and input_.is_cuda:
            return self.device_communicator.broadcast(input_, src)
        else:
            return self.tcp_store_group.broadcast(input_, src)

    def broadcast_object(self, obj=None, src: int = 0):
        """Broadcast a Python object from group rank ``src`` via the TCP store."""
        if self.world_size == 1:
            return obj
        return self.tcp_store_group.broadcast_obj(obj, src)

    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
    ):
        """Broadcast each element of ``obj_list`` from group rank ``src``.

        Receivers overwrite ``obj_list`` in place, so it must already have
        the same length as on the source. ``group`` is not used.
        """
        assert src < self.world_size

        if self.world_size == 1:
            return obj_list

        if self.rank_in_group == src:
            for obj in obj_list:
                self.tcp_store_group.broadcast_obj(obj, src)
        else:
            for i in range(len(obj_list)):
                obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)

        return obj_list

    def broadcast_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
        src: int = 0,
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Broadcast a dict of tensors and plain values from group rank ``src``.

        Metadata travels through the TCP store first; tensors follow one by
        one (empty tensors are skipped). ``group`` and ``metadata_group``
        are not used.
        """
        if self.world_size == 1:
            return tensor_dict

        if self.rank_in_group == src:
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        else:
            metadata_list = None
            tensor_list = []

        recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
            metadata_list, src
        )

        if self.rank_in_group != src:
            # Rebuild the dict, allocating receive buffers from the metadata.
            tensor_dict = {}
            for key, value in recv_metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
                    tensor_list.append(tensor)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value

        # Source and receivers walk the tensors in the same order.
        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                tensor.copy_(self.device_communicator.broadcast(tensor, src))
            else:
                tensor.copy_(self.tcp_store_group.broadcast(tensor, src))

        return tensor_dict

    def send_object(self, obj, dst: int) -> None:
        """Send a Python object to group rank ``dst`` (must not be self)."""
        assert dst < self.world_size
        assert dst != self.rank_in_group
        self.tcp_store_group.send_obj(obj, dst)

    def recv_object(self, src: int):
        """Receive a Python object from group rank ``src`` (must not be self)."""
        assert src < self.world_size
        assert src != self.rank_in_group
        return self.tcp_store_group.recv_obj(src)

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Send a dict of tensors and plain values to group rank ``dst``.

        ``dst`` defaults to the next rank in the group. With a single-member
        group the dict is returned unchanged; otherwise ``None`` is returned.
        ``all_gather_group`` and ``all_gather_tensors`` are not used.
        """
        if self.world_size == 1:
            return tensor_dict

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.tcp_store_group.send_obj(metadata_list, dst)

        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                self.device_communicator.send(tensor, dst)
            else:
                self.tcp_store_group.send(tensor, dst)

        return None

    def recv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        """Receive a dict of tensors and plain values from group rank ``src``.

        ``src`` defaults to the previous rank in the group. Returns ``None``
        for a single-member group. ``all_gather_group`` and
        ``all_gather_tensors`` are not used.
        """
        if self.world_size == 1:
            return None

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size

        recv_metadata_list = self.tcp_store_group.recv_obj(src)
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
                # Empty tensors are never sent, so keep the local placeholder.
                if tensor.numel() > 0:
                    if self.device_communicator and tensor.is_cuda:
                        tensor = self.device_communicator.recv(
                            tensor.size(), tensor.dtype, src
                        )
                    else:
                        tensor = self.tcp_store_group.recv(tensor, src)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value

        return tensor_dict

    def barrier(self):
        """Synchronize all group members through the TCP-store group."""
        self.tcp_store_group.barrier()

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """Gather ``input_`` from every rank onto group rank ``dst``.

        Returns the concatenation along ``dim`` on ``dst`` and ``None`` on
        every other rank. A single-member group returns ``input_`` as is.

        Raises:
            ValueError: if the group has no device communicator.
        """
        if self.world_size == 1:
            return input_

        if self.device_communicator is None:
            raise ValueError("No device communicator found")

        if self.rank_in_group != dst:
            self.device_communicator.send(input_, dst)
            return None

        # Build the list directly from the received tensors (ascending source
        # rank) instead of pre-allocating placeholders that would all be
        # overwritten.
        gathered_list = [
            input_
            if src_rank == self.rank_in_group
            else self.device_communicator.recv(input_.size(), input_.dtype, src_rank)
            for src_rank in range(self.world_size)
        ]
        return torch.cat(gathered_list, dim=dim)
......@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
......@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Broadcast a tensor from source rank to all other ranks.

    The source pickles ``tensor`` into the store under a per-source,
    per-message key and returns ``tensor`` itself. Every other rank loads
    the matching key and returns the unpickled tensor (the ``tensor``
    argument is not written to on receivers).
    """
    if self.rank != src:
        # Receiver: fetch the next unread message published by ``src``.
        recv_key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
        payload = self.store.get(recv_key)
        received = pickle.loads(payload)
        self.broadcast_recv_src_counter[src] += 1
        return received

    # Source: serialize first, then publish and record the key for expiry.
    serialized = pickle.dumps(tensor)
    self.expire_data()
    send_key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
    self.store.set(send_key, serialized)
    self.broadcast_send_counter += 1
    self.entries.append((send_key, time.time()))
    return tensor
def send(self, tensor: torch.Tensor, dst: int):
    """Send a tensor to a destination rank.

    The pickled tensor is stored under a key that names both the sender and
    the receiver. Without the sender rank in the key, two ranks sending to
    the same destination would write the same key and overwrite each other.
    """
    self.expire_data()
    key = f"send_tensor/{self.rank}/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(tensor))
    self.send_dst_counter[dst] += 1
    # Remember the key so expire_data() can clean it up later.
    self.entries.append((key, time.time()))

def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Receive a tensor from a source rank.

    Reads the next unread message that ``src`` addressed to this rank,
    copies it into ``tensor`` in place and returns ``tensor``.
    """
    # Must mirror the key layout used by send(): sender, receiver, sequence.
    key = f"send_tensor/{src}/{self.rank}/{self.recv_src_counter[src]}"
    # NOTE: the payload is unpickled, so the store must only be reachable
    # by trusted peers.
    received = pickle.loads(self.store.get(key))
    self.recv_src_counter[src] += 1
    tensor.copy_(received)
    return tensor
def all_reduce(
    self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
    """All-reduce a tensor across all ranks.

    Every rank's tensor is gathered with ``all_gather_obj`` and folded
    locally, starting from a clone of the first gathered tensor. The input
    tensor is not modified; a new tensor is returned.

    Raises:
        ValueError: if ``op`` is not SUM, PRODUCT, MAX or MIN. Such ops
            used to fall through silently and return the first rank's
            tensor unchanged.
    """
    reduce_op = torch.distributed.ReduceOp
    supported_ops = (
        reduce_op.SUM,
        reduce_op.PRODUCT,
        reduce_op.MAX,
        reduce_op.MIN,
    )
    # Validate before communicating so no rank gathers for a reduction that
    # cannot be computed.
    if not any(op == supported for supported in supported_ops):
        raise ValueError(f"Unsupported reduce op for StatelessProcessGroup: {op}")

    tensors = self.all_gather_obj(tensor)
    result = tensors[0].clone()
    for t in tensors[1:]:
        if op == reduce_op.SUM:
            result.add_(t)
        elif op == reduce_op.PRODUCT:
            result.mul_(t)
        elif op == reduce_op.MAX:
            result = torch.maximum(result, t)
        else:  # reduce_op.MIN, guaranteed by the validation above
            result = torch.minimum(result, t)
    return result
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
......@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
host: str,
port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
......@@ -496,26 +551,36 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
if backend == "gloo":
pg = init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
else:
from vllm.platforms import current_platform
pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
"""
......
......@@ -419,6 +419,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
......@@ -896,6 +897,9 @@ class EngineArgs:
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument(
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
)
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
......@@ -1698,6 +1702,7 @@ class EngineArgs:
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
......
......@@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None
from vllm.v1.engine.utils import get_engine_zmq_addresses
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
......
......@@ -243,6 +243,8 @@ if TYPE_CHECKING:
VLLM_LORA_DISABLE_PDL: bool = False
VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
def get_default_cache_root():
......@@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
"VLLM_CUDA_COMPATIBILITY_PATH", None
),
# Whether this is a scale-up launch engine for elastic EP.
# Should only be set by EngineCoreClient.
"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0"))
),
# Whether to wait for all requests to drain before sending the
# scaling command in elastic EP.
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
),
}
......
......@@ -627,6 +627,7 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.base_quant_method = self.quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
......@@ -683,7 +684,7 @@ class FusedMoE(CustomOp):
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
......@@ -693,7 +694,7 @@ class FusedMoE(CustomOp):
self._replace_quant_method(
FusedMoEModularMethod.make(
self,
self.quant_method,
self.base_quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
......
......@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec
# import custom ops, trigger op registration
......@@ -482,6 +485,37 @@ class CudaPlatformBase(Platform):
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Create a ``ProcessGroup`` with an NCCL backend registered for CUDA.

    ``backend`` is not consulted; an NCCL backend is always registered.
    ``timeout`` is applied through the NCCL backend options.

    Raises:
        AssertionError: if NCCL is not available.
    """
    assert is_nccl_available()
    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    process_group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    nccl_options = ProcessGroupNCCL.Options()
    nccl_options._timeout = timeout
    nccl_backend = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, nccl_options
    )

    # Order kept from the original: default backend, sequence number, then
    # backend registration for the CUDA device.
    nccl_type = ProcessGroup.BackendType.NCCL
    process_group._set_default_backend(nccl_type)
    nccl_backend._set_sequence_number_for_group()
    process_group._register_backend(torch.device("cuda"), nccl_type, nccl_backend)
    return process_group
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
......
......@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -656,6 +659,37 @@ class RocmPlatform(Platform):
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Build a ``ProcessGroup`` whose CUDA-device backend is ``ProcessGroupNCCL``.

    The ``backend`` argument is ignored; ``ProcessGroupNCCL`` is always
    used. ``timeout`` is set on the NCCL backend options.

    Raises:
        AssertionError: if ``is_nccl_available()`` is false.
    """
    assert is_nccl_available()
    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    options = ProcessGroupNCCL.Options()
    options._timeout = timeout
    nccl = ProcessGroupNCCL(prefix_store, group_rank, group_size, options)

    # Same call order as before: default backend, sequence number, then
    # registration of the backend for the "cuda" device.
    kind = ProcessGroup.BackendType.NCCL
    group._set_default_backend(kind)
    nccl._set_sequence_number_for_group()
    group._register_backend(torch.device("cuda"), kind, nccl)
    return group
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment