"components/vscode:/vscode.git/clone" did not exist on "c263a99ed268424692bcd168765f08525bd425a7"
Unverified Commit 50dd4cb4 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[EPLB] Add nixl-based eplb communicator (#36276)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Signed-off-by: default avatarMarkov Ilya <markovilya19@gmail.com>
parent f774ba02
...@@ -153,6 +153,7 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T ...@@ -153,6 +153,7 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` | | `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` | | `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
| `policy` | The policy type for expert parallel load balancing | `"default"` | | `policy` | The policy type for expert parallel load balancing | `"default"` |
| `communicator` | Backend for expert weight transfers: `"torch_nccl"`, `"torch_gloo"`, `"pynccl"`, `"nixl"`, or `null` (auto) | `null` |
For example: For example:
......
...@@ -9,7 +9,10 @@ import torch ...@@ -9,7 +9,10 @@ import torch
import torch.distributed import torch.distributed
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator from vllm.distributed.eplb.eplb_communicator import (
create_eplb_communicator,
has_nixl,
)
from vllm.distributed.eplb.rebalance_execute import ( from vllm.distributed.eplb.rebalance_execute import (
move_from_buffer, move_from_buffer,
rearrange_expert_weights_inplace, rearrange_expert_weights_inplace,
...@@ -527,7 +530,9 @@ def _test_rearrange_expert_weights_with_redundancy( ...@@ -527,7 +530,9 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16), (4, 8, 8, 16),
], ],
) )
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"]) @pytest.mark.parametrize(
"eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl", "nixl"]
)
def test_rearrange_expert_weights_with_redundancy( def test_rearrange_expert_weights_with_redundancy(
world_size, world_size,
num_layers, num_layers,
...@@ -537,6 +542,8 @@ def test_rearrange_expert_weights_with_redundancy( ...@@ -537,6 +542,8 @@ def test_rearrange_expert_weights_with_redundancy(
): ):
"""Test the functionality of rearranging expert weights with redundancy.""" """Test the functionality of rearranging expert weights with redundancy."""
if eplb_communicator == "nixl" and not has_nixl():
pytest.skip("NIXL is not available")
if torch.accelerator.device_count() < world_size: if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test") pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run( distributed_run(
...@@ -633,7 +640,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None: ...@@ -633,7 +640,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
(2, 2, 2, 3), (2, 2, 2, 3),
], ],
) )
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"]) @pytest.mark.parametrize(
"eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl", "nixl"]
)
def test_async_transfer_layer_without_mtp( def test_async_transfer_layer_without_mtp(
world_size: int, world_size: int,
num_layers: int, num_layers: int,
...@@ -643,6 +652,8 @@ def test_async_transfer_layer_without_mtp( ...@@ -643,6 +652,8 @@ def test_async_transfer_layer_without_mtp(
): ):
"""Exercise async EPLB transfer path without MTP/spec decode.""" """Exercise async EPLB transfer path without MTP/spec decode."""
if eplb_communicator == "nixl" and not has_nixl():
pytest.skip("NIXL is not available")
if torch.accelerator.device_count() < world_size: if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test") pytest.skip(f"Need at least {world_size} GPUs to run the test")
......
...@@ -36,7 +36,7 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] ...@@ -36,7 +36,7 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"] DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"] EPLBPolicyOption = Literal["default"]
DCPCommBackend = Literal["ag_rs", "a2a"] DCPCommBackend = Literal["ag_rs", "a2a"]
EPLBCommunicatorBackend = Literal["torch_nccl", "torch_gloo", "pynccl"] EPLBCommunicatorBackend = Literal["torch_nccl", "torch_gloo", "nixl", "pynccl"]
All2AllBackend = Literal[ All2AllBackend = Literal[
"naive", "naive",
"pplx", "pplx",
...@@ -90,6 +90,7 @@ class EPLBConfig: ...@@ -90,6 +90,7 @@ class EPLBConfig:
Backend for EPLB expert weight communication: Backend for EPLB expert weight communication:
- "torch_nccl": Use torch.distributed on the device process group - "torch_nccl": Use torch.distributed on the device process group
- "torch_gloo": Use torch.distributed gloo with CPU staging - "torch_gloo": Use torch.distributed gloo with CPU staging
- "nixl": Use NIXL/RIXL with staged send/recv buffers
- "pynccl": Use PyNccl send/recv - "pynccl": Use PyNccl send/recv
- None: Auto-select backend ("torch_gloo" for async, "torch_nccl" for sync) - None: Auto-select backend ("torch_gloo" for async, "torch_nccl" for sync)
""" """
......
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
EPLB communicator implementations and factory. EPLB communicator implementations and factory.
""" """
import contextlib
import time
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from datetime import timedelta
import torch import torch
from torch.distributed import ( from torch.distributed import (
...@@ -18,13 +22,27 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator ...@@ -18,13 +22,27 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclDataTypeEnum, ncclDataTypeEnum,
) )
from vllm.distributed.parallel_state import GroupCoordinator, is_local_first_rank from vllm.distributed.nixl_utils import (
NixlWrapper,
nixl_agent_config,
)
from vllm.distributed.parallel_state import (
GroupCoordinator,
get_pp_group,
is_local_first_rank,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
def has_nixl() -> bool:
    """Report whether the optional NIXL / RIXL agent class is usable.

    Returns:
        ``True`` when ``NixlWrapper`` is bound to something other than
        ``None``, ``False`` otherwise.
    """
    nixl_missing = NixlWrapper is None
    return not nixl_missing
class EplbCommunicator(ABC): class EplbCommunicator(ABC):
"""Abstract EPLB communicator for expert weight transfers.""" """Abstract EPLB communicator for expert weight transfers."""
...@@ -40,6 +58,12 @@ class EplbCommunicator(ABC): ...@@ -40,6 +58,12 @@ class EplbCommunicator(ABC):
def execute(self) -> None: def execute(self) -> None:
pass pass
@property
def needs_profile_buffer_reservation(self) -> bool:
    """Tell the profile path whether it has to reserve communication buffers.

    Returns:
        ``True`` by default, meaning the profile path must run a dummy
        collective operation so that communication buffers get reserved.
        Subclasses may override this to opt out.
    """
    reserve_with_dummy_collective = True
    return reserve_with_dummy_collective
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None: def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
self._cuda_stream = cuda_stream self._cuda_stream = cuda_stream
...@@ -167,6 +191,385 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator): ...@@ -167,6 +191,385 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
dst_tensor.copy_(cpu_tensor, non_blocking=True) dst_tensor.copy_(cpu_tensor, non_blocking=True)
class NixlEplbCommunicator(EplbCommunicator):
    """EPLB communicator backed by NIXL READ transfers.

    Tensors queued with ``add_send`` are packed into a registered send buffer
    that holds one fixed-size partition per rank. On ``execute`` every rank
    issues a NIXL READ against its own partition of each sender's buffer,
    lands the bytes in a registered receive buffer, and copies them into the
    tensors queued with ``add_recv``.
    """

    # Seconds to sleep between two polls of the in-flight transfer states.
    _POLL_INTERVAL_S = 0.0005

    def __init__(
        self,
        cpu_group: ProcessGroup,
        expert_weights: Sequence[torch.Tensor],
        cuda_stream: torch.cuda.Stream | None = None,
    ) -> None:
        """Create the NIXL agent, register buffers and exchange peer metadata.

        Args:
            cpu_group: Process group used for the ``all_gather_object``
                metadata exchanges and for the barrier in ``execute``.
            expert_weights: Non-empty sequence of local expert tensors. They
                determine the device, the set of dtypes, and the size of the
                staging buffers.
            cuda_stream: Stream on which pack/unpack copies are enqueued.

        Raises:
            RuntimeError: If NIXL/RIXL is unavailable or an init step fails.
        """
        assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
        if NixlWrapper is None:
            raise RuntimeError("NIXL/ RIXL is unavailable.")
        self._cpu_group = cpu_group
        self._cuda_stream = cuda_stream
        self._world_size = cpu_group.size()
        self._rank = cpu_group.rank()
        # Per dtype: one list of queued tensors for every peer rank.
        self._send_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {}
        self._recv_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {}
        # Dtypes in first-seen order; this order fixes the packing layout.
        self._dtypes: list[torch.dtype] = []
        self._device = expert_weights[0].device
        for tensor in expert_weights:
            assert tensor.device == self._device, (
                "All local EPLB tensors are expected to be on the same device: "
                f"expected={self._device}, got={tensor.device}"
            )
            if tensor.dtype not in self._dtypes:
                self._dtypes.append(tensor.dtype)

        config = (
            nixl_agent_config(capture_telemetry=False)
            if nixl_agent_config is not None
            else None
        )
        self._nixl_wrapper = NixlWrapper(self._make_agent_name(), config)
        self._nixl_memory_type = "VRAM"
        self._registered_desc: object | None = None
        # peer rank -> remote agent name returned by add_remote_agent.
        self._remote_agents: dict[int, str] = {}
        # peer rank -> (send buffer base address, partition bytes, device id).
        self._remote_send_meta: dict[int, tuple[int, int, int]] = {}
        self._send_buffer: torch.Tensor = torch.empty(0)
        self._recv_buffer: torch.Tensor = torch.empty(0)
        self._peer_partition_bytes: int = 0
        self._dtype_max_bytes: dict[torch.dtype, int] = {}
        self._cuda_device_id = int(self._device.index or 0)
        # (src, total_bytes, recv_offset) -> (local dlist, remote dlist, xfer).
        self._xfer_cache: dict[tuple[int, int, int], tuple[int, int, int]] = {}
        self._init_step("buffers", self._init_registered_buffers, expert_weights)
        self._init_step("agents", self._init_remote_agents)
        self._init_step("send meta", self._exchange_remote_send_meta)
        self._log_initialized()

    @property
    def needs_profile_buffer_reservation(self) -> bool:
        """Return ``False``: the profile path need not reserve buffers."""
        return False

    @staticmethod
    def _init_step(name: str, fn: object, *args: object, **kwargs: object) -> None:
        """Run one init step, re-raising any failure tagged with ``name``."""
        try:
            fn(*args, **kwargs)  # type: ignore[operator]
        except Exception as exc:
            raise RuntimeError(f"NIXL EPLB init failed: {name}") from exc

    def _make_agent_name(self) -> str:
        """Build a deployment-unique nixl agent name.

        The name combines this rank, the pipeline-parallel rank (only when the
        PP group has more than one member) and a random 8-hex-digit suffix.
        """
        pp_size = get_pp_group().world_size
        pp_suffix = f"-pp{get_pp_group().rank_in_group}" if pp_size > 1 else ""
        uid = uuid.uuid4().hex[:8]
        return f"eplb-{self._rank}{pp_suffix}-{uid}"

    def _get_peer_buckets(
        self,
        bucket_map: dict[torch.dtype, list[list[torch.Tensor]]],
        dtype: torch.dtype,
    ) -> list[list[torch.Tensor]]:
        """Return the per-peer buckets for ``dtype``, creating them on demand."""
        peer_buckets = bucket_map.get(dtype)
        if peer_buckets is None:
            peer_buckets = [[] for _ in range(self._world_size)]
            bucket_map[dtype] = peer_buckets
        return peer_buckets

    @staticmethod
    def _queued_tensors(
        bucket_map: dict[torch.dtype, list[list[torch.Tensor]]],
        dtype: torch.dtype,
        peer: int,
    ) -> list[torch.Tensor]:
        """Return the tensors queued for ``peer``/``dtype`` without mutating."""
        peer_buckets = bucket_map.get(dtype)
        return [] if peer_buckets is None else peer_buckets[peer]

    @staticmethod
    def _total_nbytes(tensors: list[torch.Tensor]) -> int:
        """Return the summed byte size of ``tensors``."""
        return sum(t.numel() * t.element_size() for t in tensors)

    def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
        """Queue ``tensor`` to be made available to ``dst_rank``."""
        assert dst_rank != self._rank, (
            "EPLB communicator should not enqueue same-rank sends: "
            f"rank={self._rank}, dst_rank={dst_rank}"
        )
        self._get_peer_buckets(self._send_tensors, tensor.dtype)[dst_rank].append(
            tensor
        )

    def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
        """Queue ``tensor`` to be filled with data read from ``src_rank``."""
        assert src_rank != self._rank, (
            "EPLB communicator should not enqueue same-rank recvs: "
            f"rank={self._rank}, src_rank={src_rank}"
        )
        self._get_peer_buckets(self._recv_tensors, tensor.dtype)[src_rank].append(
            tensor
        )

    def _init_remote_agents(self) -> None:
        """All-gather agent metadata and register every other rank's agent."""
        local_metadata = self._nixl_wrapper.get_agent_metadata()
        gathered_metadata: list[bytes | None] = [None] * self._world_size
        torch.distributed.all_gather_object(
            gathered_metadata, local_metadata, group=self._cpu_group
        )
        for peer, peer_metadata in enumerate(gathered_metadata):
            if peer == self._rank:
                continue
            assert peer_metadata is not None
            self._remote_agents[peer] = self._nixl_wrapper.add_remote_agent(
                peer_metadata
            )

    def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
        """Size, allocate and register the send and receive staging buffers."""
        total_max_bytes = 0
        for dtype in self._dtypes:
            # Clamp to one element so every dtype gets a non-zero capacity.
            max_numel = max(
                sum(t.numel() for t in expert_weights if t.dtype == dtype), 1
            )
            max_bytes = max_numel * dtype.itemsize
            self._dtype_max_bytes[dtype] = max_bytes
            total_max_bytes += max_bytes

        self._peer_partition_bytes = total_max_bytes
        # The send buffer needs world_size partitions because remote peers
        # READ from fixed offsets (rank * partition_bytes).
        # This allocates world_size * partition_bytes
        # which can cause OOM on large models.
        # TODO(ilmarkov): shrink to const * partition_bytes and execute
        # communication in multiple steps dealing with the worst case.
        send_total_bytes = self._peer_partition_bytes * self._world_size
        self._send_buffer = torch.empty(
            send_total_bytes, device=self._device, dtype=torch.uint8
        )
        self._recv_buffer = torch.empty(
            self._peer_partition_bytes, device=self._device, dtype=torch.uint8
        )
        descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
        self._nixl_wrapper.register_memory(descs)
        self._registered_desc = descs

    def _exchange_remote_send_meta(self) -> None:
        """Exchange send-buffer metadata so each rank can build dynamic
        descriptors at execute time."""
        local_meta: tuple[int, int, int] = (
            self._send_buffer.data_ptr(),
            self._peer_partition_bytes,
            self._cuda_device_id,
        )
        gathered_meta: list[tuple[int, int, int] | None] = [None] * self._world_size
        torch.distributed.all_gather_object(
            gathered_meta, local_meta, group=self._cpu_group
        )
        for peer in self._remote_agents:
            peer_meta = gathered_meta[peer]
            assert peer_meta is not None
            self._remote_send_meta[peer] = peer_meta

    @staticmethod
    def _pack_send_buffer(
        peer_tensors: list[torch.Tensor],
        send_buffer: torch.Tensor,
        byte_offset: int,
    ) -> int:
        """Copy ``peer_tensors`` back to back into ``send_buffer``.

        Returns the byte offset after the last written byte.
        """
        for tensor in peer_tensors:
            raw = tensor.reshape(-1).view(torch.uint8)
            num_bytes = raw.numel()
            if num_bytes == 0:
                continue
            send_buffer[byte_offset : byte_offset + num_bytes].copy_(
                raw, non_blocking=True
            )
            byte_offset += num_bytes
        return byte_offset

    @staticmethod
    def _unpack_recv_buffer(
        recv_buffer: torch.Tensor,
        peer_tensors: list[torch.Tensor],
        byte_offset: int,
    ) -> int:
        """Copy consecutive byte ranges of ``recv_buffer`` into ``peer_tensors``.

        Returns the byte offset after the last read byte.
        """
        for tensor in peer_tensors:
            num_bytes = tensor.numel() * tensor.element_size()
            if num_bytes == 0:
                continue
            tensor.reshape(-1).view(torch.uint8).copy_(
                recv_buffer[byte_offset : byte_offset + num_bytes],
                non_blocking=True,
            )
            byte_offset += num_bytes
        return byte_offset

    def _release_all_cached_handles(self) -> None:
        """Best-effort release of every cached dlist and xfer handle."""
        for local_dlist, remote_dlist, xfer in self._xfer_cache.values():
            # Release the transfer first, then the two descriptor lists.
            for release_fn, handle in (
                (self._nixl_wrapper.release_xfer_handle, xfer),
                (self._nixl_wrapper.release_dlist_handle, local_dlist),
                (self._nixl_wrapper.release_dlist_handle, remote_dlist),
            ):
                with contextlib.suppress(Exception):
                    release_fn(handle)
        self._xfer_cache.clear()

    def _wait_for_all_transfers(self, handles: list[int]) -> None:
        """Poll until every handle reports ``DONE``.

        Raises:
            RuntimeError: If a handle reports a state other than ``DONE`` or
                ``PROC``.
        """
        pending = set(handles)
        while pending:
            completed: list[int] = []
            for handle in pending:
                state = self._nixl_wrapper.check_xfer_state(handle)
                if state == "DONE":
                    completed.append(handle)
                elif state != "PROC":
                    raise RuntimeError(f"NIXL transfer failed with state={state}")
            pending.difference_update(completed)
            if pending:
                time.sleep(self._POLL_INTERVAL_S)

    def _get_or_create_xfer(self, src: int, total_bytes: int, recv_offset: int) -> int:
        """Return a cached xfer handle or create and cache a new one."""
        key = (src, total_bytes, recv_offset)
        cached = self._xfer_cache.get(key)
        if cached is not None:
            return cached[2]

        # Local side: the slice of our recv buffer that the READ fills.
        local_desc = self._nixl_wrapper.get_xfer_descs(
            [
                (
                    self._recv_buffer.data_ptr() + recv_offset,
                    total_bytes,
                    self._cuda_device_id,
                )
            ],
            self._nixl_memory_type,
        )
        local_handle = self._nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT",
            local_desc,
        )

        # Remote side: our partition inside the sender's send buffer.
        remote_base, remote_part_bytes, remote_dev = self._remote_send_meta[src]
        remote_desc = self._nixl_wrapper.get_xfer_descs(
            [
                (
                    remote_base + self._rank * remote_part_bytes,
                    total_bytes,
                    remote_dev,
                )
            ],
            self._nixl_memory_type,
        )
        remote_handle = self._nixl_wrapper.prep_xfer_dlist(
            self._remote_agents[src],
            remote_desc,
        )

        xfer_handle = self._nixl_wrapper.make_prepped_xfer(
            "READ",
            local_handle,
            [0],
            remote_handle,
            [0],
        )
        self._xfer_cache[key] = (local_handle, remote_handle, xfer_handle)
        return xfer_handle

    def execute(self) -> None:
        """Run all queued sends and receives, then clear the queues.

        Raises:
            RuntimeError: If queued data exceeds the send or receive buffer
                capacity, or a NIXL transfer reports a failure state.
        """
        xfer_handles: list[int] = []
        try:
            # Phase 1: pack send buffers.
            # NOTE(review): nothing in this method stops a fast rank from
            # re-packing its send buffer in a later execute() while a slower
            # peer is still READing the previous contents - confirm that
            # callers synchronize ranks between consecutive calls.
            with torch.cuda.stream(self._cuda_stream):
                for dst in range(self._world_size):
                    byte_offset = dst * self._peer_partition_bytes
                    for dtype in self._dtypes:
                        peer_tensors = self._queued_tensors(
                            self._send_tensors, dtype, dst
                        )
                        actual_bytes = self._total_nbytes(peer_tensors)
                        if actual_bytes > self._dtype_max_bytes[dtype]:
                            raise RuntimeError(
                                "NIXL EPLB send overflow for dtype "
                                f"{dtype}: peer={dst}, "
                                f"required={actual_bytes}, "
                                f"capacity={self._dtype_max_bytes[dtype]}"
                            )
                        byte_offset = self._pack_send_buffer(
                            peer_tensors,
                            self._send_buffer,
                            byte_offset,
                        )

            # Ensure all packed data is visible in device memory before pulls.
            if self._cuda_stream is not None:
                self._cuda_stream.synchronize()
            else:
                torch.cuda.current_stream().synchronize()

            # READ is receiver-initiated; synchronize all ranks before transfer.
            # We use monitored_barrier so a rank that crashes or exits early
            # produces a diagnostic timeout instead of a silent hang.
            torch.distributed.monitored_barrier(
                group=self._cpu_group,
                timeout=timedelta(minutes=5),
            )

            # Phase 2: plan the receive layout, then look up or create
            # descriptors and issue all READs. Data from all peers is packed
            # sequentially into the single partition-sized recv buffer at
            # running offsets.
            recv_plan: list[tuple[int, int, int]] = []  # (src, offset, nbytes)
            recv_offset = 0
            for src in range(self._world_size):
                if src == self._rank:
                    continue
                actual_total_bytes = sum(
                    self._total_nbytes(
                        self._queued_tensors(self._recv_tensors, dtype, src)
                    )
                    for dtype in self._dtypes
                )
                if actual_total_bytes == 0:
                    continue
                recv_plan.append((src, recv_offset, actual_total_bytes))
                recv_offset += actual_total_bytes

            # Validate before issuing anything so no READ targets memory past
            # the end of the registered recv buffer.
            recv_capacity = self._recv_buffer.numel()
            if recv_offset > recv_capacity:
                raise RuntimeError(
                    "NIXL EPLB recv overflow: "
                    f"rank={self._rank}, "
                    f"required={recv_offset}, "
                    f"capacity={recv_capacity}"
                )

            for src, offset, nbytes in recv_plan:
                xfer_handle = self._get_or_create_xfer(src, nbytes, offset)
                self._nixl_wrapper.transfer(xfer_handle)
                xfer_handles.append(xfer_handle)

            # Phase 3: single wait for all in-flight transfers, then unpack.
            self._wait_for_all_transfers(xfer_handles)
            with torch.cuda.stream(self._cuda_stream):
                for src, offset, _ in recv_plan:
                    byte_offset = offset
                    # Same dtype order as the sender used when packing.
                    for dtype in self._dtypes:
                        byte_offset = self._unpack_recv_buffer(
                            self._recv_buffer,
                            self._queued_tensors(self._recv_tensors, dtype, src),
                            byte_offset,
                        )
        except Exception:
            self._release_all_cached_handles()
            raise
        finally:
            self._send_tensors.clear()
            self._recv_tensors.clear()

    def __del__(self) -> None:
        """Release NIXL handles, registered memory and remote agents."""
        # __init__ can raise before the agent exists; then there is nothing to
        # clean up and touching the missing attributes would only log noise.
        wrapper = getattr(self, "_nixl_wrapper", None)
        if wrapper is None:
            return
        try:
            if getattr(self, "_xfer_cache", None):
                self._release_all_cached_handles()
            if getattr(self, "_registered_desc", None) is not None:
                wrapper.deregister_memory(self._registered_desc)
                self._registered_desc = None
            remote_agents = getattr(self, "_remote_agents", {})
            for agent_name in remote_agents.values():
                wrapper.remove_remote_agent(agent_name)
            remote_agents.clear()
        except Exception as e:
            logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
class PyNcclEplbCommunicator(EplbCommunicator): class PyNcclEplbCommunicator(EplbCommunicator):
"""EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv.""" """EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv."""
...@@ -204,6 +607,24 @@ def create_eplb_communicator( ...@@ -204,6 +607,24 @@ def create_eplb_communicator(
backend: str | None, backend: str | None,
expert_weights: Sequence[torch.Tensor], expert_weights: Sequence[torch.Tensor],
) -> EplbCommunicator: ) -> EplbCommunicator:
"""Create an EPLB communicator for the given backend.
Args:
group_coordinator: Process-group coordinator that provides the
device and CPU communication groups.
backend: Communicator backend name (``"torch_nccl"``,
``"torch_gloo"``, ``"pynccl"``, or ``"nixl"``).
Falls back to ``"torch_nccl"`` when *None*.
Stateless (elastic EP) groups only support ``"torch_nccl"``
and ``"pynccl"``; ``"torch_nccl"`` is silently promoted to
``"pynccl"`` in that case. When tensors reside on CPU,
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
process group.
expert_weights: Expert weight tensors from *one* MoE layer.
NixlEplbCommunicator pre-allocates send/recv buffers sized
to this layer, so all other MoE layers must have the same
tensor count, shapes, and dtypes.
"""
# Keep a safe default for callers that have not resolved communicator yet. # Keep a safe default for callers that have not resolved communicator yet.
if backend is None: if backend is None:
backend = "torch_nccl" backend = "torch_nccl"
...@@ -256,7 +677,7 @@ def create_eplb_communicator( ...@@ -256,7 +677,7 @@ def create_eplb_communicator(
if backend not in ("torch_nccl", "pynccl"): if backend not in ("torch_nccl", "pynccl"):
raise ValueError( raise ValueError(
f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator " f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator "
f"(got '{backend}'). torch_gloo is not supported with stateless groups." f"(got '{backend}')."
) )
if backend == "torch_nccl": if backend == "torch_nccl":
logger.warning( logger.warning(
...@@ -266,7 +687,26 @@ def create_eplb_communicator( ...@@ -266,7 +687,26 @@ def create_eplb_communicator(
backend = "pynccl" backend = "pynccl"
return _create_pynccl() return _create_pynccl()
if backend == "torch_gloo": if backend == "nixl":
if not has_nixl():
raise RuntimeError(
"EPLB communicator 'nixl' requested but NIXL is unavailable."
)
if not (current_platform.is_cuda_alike() and tensor_device_type != "cpu"):
raise RuntimeError(
"EPLB communicator 'nixl' supports only cuda-like devices "
f"(got {tensor_device_type})."
)
try:
return NixlEplbCommunicator(
cpu_group=group_coordinator.cpu_group,
expert_weights=expert_weights,
)
except Exception as exc:
raise RuntimeError(
f"Failed to initialize NixlEplbCommunicator ({exc})."
) from exc
elif backend == "torch_gloo":
return TorchDistGlooStagedEplbCommunicator( return TorchDistGlooStagedEplbCommunicator(
cpu_group=group_coordinator.cpu_group, cpu_group=group_coordinator.cpu_group,
) )
......
...@@ -541,14 +541,15 @@ def rearrange_expert_weights_inplace( ...@@ -541,14 +541,15 @@ def rearrange_expert_weights_inplace(
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
first_layer_weights = list(expert_weights[0]) first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers if is_profile:
# have the same shape. if communicator.needs_profile_buffer_reservation:
# Reserve NCCL communication buffers via a dummy all_gather.
# Backends that pre-allocate their own transfer buffers
# skip this to avoid the extra memory spike during profiling.
weights_buffer: list[torch.Tensor] = [ weights_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights torch.empty_like(w) for w in first_layer_weights
] ]
if is_profile:
# Reserve communication buffers via a minimal dummy all_gather on first layer
for weight, buffer in zip(expert_weights[0], weights_buffer): for weight, buffer in zip(expert_weights[0], weights_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)] dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier() torch.distributed.barrier()
...@@ -559,6 +560,11 @@ def rearrange_expert_weights_inplace( ...@@ -559,6 +560,11 @@ def rearrange_expert_weights_inplace(
) )
return return
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
# NOTE(bowen): We need this synchronize to run, but I don't know why. # NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you! # If you figure out the reason, please let me know -- thank you!
torch.accelerator.synchronize() torch.accelerator.synchronize()
......
...@@ -15,9 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( ...@@ -15,9 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
PromMetric, PromMetric,
PromMetricT, PromMetricT,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( from vllm.distributed.nixl_utils import nixlXferTelemetry
nixlXferTelemetry,
)
from vllm.v1.metrics.utils import create_metric_per_engine from vllm.v1.metrics.utils import create_metric_per_engine
......
...@@ -3,60 +3,14 @@ ...@@ -3,60 +3,14 @@
"""Shared constants, lazy imports and helpers for the NIXL connector.""" """Shared constants, lazy imports and helpers for the NIXL connector."""
import contextlib import contextlib
import os
import sys
from collections.abc import Iterator from collections.abc import Iterator
from typing import Any from typing import Any
import zmq import zmq
from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_socket from vllm.utils.network_utils import make_zmq_socket
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
# avoid a memory leak in UCX when using NIXL on some models
# see: https://github.com/vllm-project/vllm/issues/24264
if "nixl" in sys.modules or "rixl" in sys.modules:
logger.warning(
"NIXL was already imported, we can't reset UCX_RCACHE_MAX_UNRELEASED. "
"Please set it to '1024' manually."
)
else:
logger.info(
"Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
"memory leak in UCX when using NIXL."
)
os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"
if not current_platform.is_rocm():
from nixl._api import nixl_agent as NixlWrapper
from nixl._bindings import nixlXferTelemetry
else:
from rixl._api import nixl_agent as NixlWrapper
from rixl._bindings import nixlXferTelemetry
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
nixlXferTelemetry = None
try:
if not current_platform.is_rocm():
from nixl._api import nixl_agent_config
else:
from rixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer. # Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types} # {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = { _NIXL_SUPPORTED_DEVICE = {
......
...@@ -45,8 +45,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import ( ...@@ -45,8 +45,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
_NIXL_SUPPORTED_DEVICE, _NIXL_SUPPORTED_DEVICE,
NixlWrapper,
nixl_agent_config,
zmq_ctx, zmq_ctx,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
...@@ -54,6 +52,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ...@@ -54,6 +52,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import
compute_mamba_phys_ratio, compute_mamba_phys_ratio,
derive_mamba_conv_split, derive_mamba_conv_split,
) )
from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)

# Shared, optional NIXL / RIXL imports. Each public name is bound to ``None``
# when its import fails, so callers must check before use.

# This has to happen before nixl/rixl is imported for the first time; once
# either module is loaded the value can no longer be reset from here.
if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
    # Avoid a memory leak in UCX when using NIXL on some models.
    # See: https://github.com/vllm-project/vllm/issues/24264
    if "nixl" in sys.modules or "rixl" in sys.modules:
        logger.warning_once(
            "NIXL was already imported, we can't reset "
            "UCX_RCACHE_MAX_UNRELEASED. "
            "Please set it to '1024' manually."
        )
    else:
        logger.info_once(
            "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
            "memory leak in UCX when using NIXL."
        )
        os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"

# Only ROCm uses the ``rixl`` package; every other platform imports ``nixl``.
# Selecting on ``is_cuda()`` instead would send all non-CUDA, non-ROCm
# platforms to ``rixl`` and make NIXL look unavailable there.
try:
    if not current_platform.is_rocm():
        from nixl._api import nixl_agent as NixlWrapper
    else:
        from rixl._api import nixl_agent as NixlWrapper
    logger.info_once("NIXL is available")
except ImportError:
    logger.warning_once("NIXL is not available")
    NixlWrapper = None  # type: ignore[assignment, misc]

try:
    if not current_platform.is_rocm():
        from nixl._api import nixl_agent_config
    else:
        from rixl._api import nixl_agent_config
except ImportError:
    nixl_agent_config = None  # type: ignore[assignment]
    logger.warning_once("NIXL agent config is not available")

try:
    if not current_platform.is_rocm():
        from nixl._bindings import nixlXferTelemetry
    else:
        from rixl._bindings import nixlXferTelemetry
except ImportError:
    nixlXferTelemetry = None  # type: ignore[assignment, misc]

__all__ = ["NixlWrapper", "nixl_agent_config", "nixlXferTelemetry"]
...@@ -4821,6 +4821,23 @@ class GPUModelRunner( ...@@ -4821,6 +4821,23 @@ class GPUModelRunner(
) )
self.model.set_aux_hidden_state_layers(aux_layers) self.model.set_aux_hidden_state_layers(aux_layers)
if (
is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb
and not load_dummy_weights
):
logger.info_once(
"EPLB is enabled for model %s.",
self.model_config.model,
)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model,
self.model_config,
)
eplb_models += 1
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
except torch.cuda.OutOfMemoryError as e: except torch.cuda.OutOfMemoryError as e:
...@@ -4860,14 +4877,9 @@ class GPUModelRunner( ...@@ -4860,14 +4877,9 @@ class GPUModelRunner(
is_mixture_of_experts(self.model) is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb and self.parallel_config.enable_eplb
and not load_dummy_weights and not load_dummy_weights
and self.eplb_state is not None
and self.eplb_state.is_async
): ):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model,
self.model_config,
)
if self.eplb_state.is_async:
self.eplb_state.start_async_loop() self.eplb_state.start_async_loop()
if ( if (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment