Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = { ...@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
"10.0": { "10.0": {
2: 2 * MiB, # 2 MB 2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB 4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB 6: 1 * MiB, # 1 MB
8: 2 * MiB, # 2 MB 8: 1 * MiB, # 1 MB
} }
} }
......
...@@ -252,7 +252,10 @@ class DeviceCommunicatorBase: ...@@ -252,7 +252,10 @@ class DeviceCommunicatorBase:
moe_modules = [ moe_modules = [
module for module in model.modules() module for module in model.modules()
if module.__class__.__name__ == "FusedMoE" # TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module) module.quant_method.init_prepare_finalize(module)
......
...@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1: if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
symm_mem_enabled=(self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled),
) )
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series. # currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group, self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device) device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all: if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND all2all_backend = envs.VLLM_ALL2ALL_BACKEND
......
...@@ -54,13 +54,14 @@ class CustomAllreduce: ...@@ -54,13 +54,14 @@ class CustomAllreduce:
def __init__(self, def __init__(self,
group: ProcessGroup, group: ProcessGroup,
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
""" """
Args: Args:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
default process group. default process group.
device: the device to bind the CustomAllreduce to. If None, device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}". it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group is bound to a unique device, and all communicators in this group
are in the same node. are in the same node.
...@@ -111,7 +112,7 @@ class CustomAllreduce: ...@@ -111,7 +112,7 @@ class CustomAllreduce:
self.device = device self.device = device
device_capability = current_platform.get_device_capability( device_capability = current_platform.get_device_capability(
).as_version_str() ).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min( max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
...@@ -159,7 +160,7 @@ class CustomAllreduce: ...@@ -159,7 +160,7 @@ class CustomAllreduce:
self.disabled = False self.disabled = False
# Buffers memory are owned by this Python class and passed to C++. # Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a # Metadata composes of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results. # temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group, group=group,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform
if current_platform.is_neuron():
import torch_xla.core.xla_model as xm
class NeuronCommunicator(DeviceCommunicatorBase):
    """Device communicator that routes collectives through ``xm``.

    ``xm`` is ``torch_xla.core.xla_model``; this module imports it only
    when ``current_platform.is_neuron()`` is true.
    """

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of ``xm.all_reduce`` on ``x`` using
        ``xm.REDUCE_SUM``."""
        reduced = xm.all_reduce(xm.REDUCE_SUM, x)
        return reduced

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Return the result of ``xm.all_gather`` on ``x``.

        Only ``dim == -1`` is accepted; any other value trips the
        assertion below.
        """
        assert dim == -1, "Neuron only supports dim=-1 for all-gather."
        gathered = xm.all_gather(x, dim=dim)
        return gathered
...@@ -31,7 +31,7 @@ class PyNcclCommunicator: ...@@ -31,7 +31,7 @@ class PyNcclCommunicator:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
default process group. default process group.
device: the device to bind the PyNcclCommunicator to. If None, device: the device to bind the PyNcclCommunicator to. If None,
it will be bind to f"cuda:{local_rank}". it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will library_path: the path to the NCCL library. If None, it will
use the default library path. use the default library path.
It is the caller's responsibility to make sure each communicator It is the caller's responsibility to make sure each communicator
......
...@@ -78,7 +78,7 @@ class QuickAllReduce: ...@@ -78,7 +78,7 @@ class QuickAllReduce:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
default process group. default process group.
device: the device to bind the CustomAllreduce to. If None, device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}". it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group is bound to a unique device, and all communicators in this group
are in the same node. are in the same node.
......
...@@ -186,7 +186,7 @@ class RayPPCommunicator(Communicator): ...@@ -186,7 +186,7 @@ class RayPPCommunicator(Communicator):
""" """
Receive a torch.Tensor from a peer and synchronize the current stream. Receive a torch.Tensor from a peer and synchronize the current stream.
After this call returns, the receive buffer is safe to read from from After this call returns, the receive buffer is safe to read from
any stream. A RayChannelError will be raised if an error occurred any stream. A RayChannelError will be raised if an error occurred
(e.g., remote actor died), and the buffer is not safe to read. (e.g., remote actor died), and the buffer is not safe to read.
......
...@@ -27,8 +27,13 @@ class SymmMemCommunicator: ...@@ -27,8 +27,13 @@ class SymmMemCommunicator:
"10.0": [6, 8], "10.0": [6, 8],
} }
def __init__(self, group: ProcessGroup, device: Union[int, str, def __init__(
torch.device]): self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self.disabled = True self.disabled = True
if not symm_mem_available: if not symm_mem_available:
...@@ -64,8 +69,17 @@ class SymmMemCommunicator: ...@@ -64,8 +69,17 @@ class SymmMemCommunicator:
self.world_size, self.world_size,
) )
return return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ # Use override max_size if provided, otherwise use default
self.world_size] if max_size_override is not None:
self.max_size = max_size_override
logger.info(
"SymmMemCommunicator: Using override max_size: %s bytes",
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.buffer = torch_symm_mem.empty( self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize, self.max_size // self.dtype.itemsize,
device=self.device, device=self.device,
...@@ -76,6 +90,7 @@ class SymmMemCommunicator: ...@@ -76,6 +90,7 @@ class SymmMemCommunicator:
logger.warning("SymmMemCommunicator: symmetric memory " logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.") "multicast operations are not supported.")
return return
self.force_multimem = force_multimem
self.disabled = False self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor): def should_use_symm_mem(self, inp: torch.Tensor):
...@@ -98,8 +113,18 @@ class SymmMemCommunicator: ...@@ -98,8 +113,18 @@ class SymmMemCommunicator:
if out is None: if out is None:
out = torch.empty_like(inp) out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1)) self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]: # Determine which algorithm to use
use_multimem = False
if self.force_multimem is not None:
# Test override: use forced setting
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum", "sum",
self.group.group_name) self.group.group_name)
......
...@@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union ...@@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union
import msgspec import msgspec
import zmq import zmq
from vllm.config import KVEventsConfig from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU" ...@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[int] parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
...@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent): ...@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[ExternalBlockHash]
medium: Optional[str] medium: Optional[str]
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, KVConnectorBaseType, ensure_kv_transfer_initialized,
has_kv_transfer_group, is_v1_kv_transfer_group) ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group)
__all__ = [ __all__ = [
"get_kv_transfer_group", "has_kv_transfer_group", "get_kv_transfer_group", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"KVConnectorBaseType" "ensure_kv_transfer_shutdown", "KVConnectorBaseType"
] ]
...@@ -14,7 +14,8 @@ from vllm.logger import init_logger ...@@ -14,7 +14,8 @@ from vllm.logger import init_logger
# yapf: enable # yapf: enable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -6,7 +6,7 @@ KV cache helper for store. ...@@ -6,7 +6,7 @@ KV cache helper for store.
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from concurrent.futures import CancelledError, Future from concurrent.futures import CancelledError, Future
from typing import Optional, cast from typing import Literal, Optional, Union, cast
import torch import torch
...@@ -196,3 +196,51 @@ class KVOutputAggregator: ...@@ -196,3 +196,51 @@ class KVOutputAggregator:
output_future.add_done_callback(make_callback(i)) output_future.add_done_callback(make_callback(i))
return result_future return result_future
def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn two lists of block ids into int64 index tensors.

    Args:
        src_block_ids: block ids on the source side.
        dst_block_ids: block ids on the destination side.
        src_device: device on which the source index tensor is created.
        dst_device: device on which the destination index tensor is
            created.

    Returns:
        ``(src_indices, dst_indices)``; each is a ``torch.int64`` tensor
        holding the corresponding ids, placed on its own device.
    """

    def _to_index_tensor(block_ids, device):
        # Both tensors are built the same way; only ids/device differ.
        return torch.tensor(block_ids, device=device, dtype=torch.int64)

    # Source first, then destination, as a tuple.
    return (
        _to_index_tensor(src_block_ids, src_device),
        _to_index_tensor(dst_block_ids, dst_device),
    )
def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy selected KV blocks, layer by layer, from one cache to another.

    The call is a silent no-op when any of the four inputs is empty or
    when the two block-id lists differ in length.

    Args:
        src_kv_caches: source tensors keyed by layer name.
        dst_kv_caches: destination tensors, looked up with the same
            layer names as ``src_kv_caches``.
        src_block_ids: block ids to read from the source tensors.
        dst_block_ids: block ids to write in the destination tensors.
        direction: ``"h2d"`` selects
            ``current_platform.insert_blocks_to_device``; any other value
            selects ``current_platform.swap_out_blocks_to_host``.
    """
    inputs_present = all(
        (src_kv_caches, dst_kv_caches, src_block_ids, dst_block_ids))
    if not inputs_present or len(src_block_ids) != len(dst_block_ids):
        return

    # The first tensor of each mapping decides where the index tensors live.
    first_src = next(iter(src_kv_caches.values()))
    first_dst = next(iter(dst_kv_caches.values()))
    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=first_src.device,
        dst_device=first_dst.device)

    # Imported inside the function, as in the original.
    from vllm.platforms import current_platform
    copy_fn = (current_platform.insert_blocks_to_device
               if direction == "h2d" else
               current_platform.swap_out_blocks_to_host)

    for layer_name, src_tensor in src_kv_caches.items():
        copy_fn(src_tensor, dst_kv_caches[layer_name], src_indices,
                dst_indices)
...@@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC): ...@@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
""" """
return None, None return None, None
def shutdown(self):
    """
    Shut the connector down.

    Invoked while the worker process is shutting down, so that all async
    operations are completed and the connector is cleaned up properly.
    This implementation does nothing and returns None.
    """
    return
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
...@@ -235,7 +243,7 @@ class KVConnectorBase_V1(ABC): ...@@ -235,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
...@@ -247,8 +255,11 @@ class KVConnectorBase_V1(ABC): ...@@ -247,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
A tuple with the following elements: A tuple with the following elements:
- The number of tokens that can be loaded from the - An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded - `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be asynchronously (between scheduler steps). Must be
'False' if the first element is 0. 'False' if the first element is 0.
......
...@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
......
...@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
...@@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1):
for c in self._connectors: for c in self._connectors:
c.clear_connector_metadata() c.clear_connector_metadata()
def shutdown(self):
    """Shut down every connector in ``self._connectors``.

    Each connector's ``shutdown()`` is attempted even if an earlier one
    failed; every failure is logged with its traceback.

    Raises:
        Exception: the exception raised by the last failing connector,
            re-raised after all connectors have been attempted.
    """
    last_error: Optional[Exception] = None
    for connector in self._connectors:
        try:
            connector.shutdown()
        except Exception as err:
            # Keep going so one failing connector cannot stop the
            # remaining connectors from being shut down.
            logger.exception("Exception during connector %s shutdown.",
                             connector.__class__.__name__)
            last_error = err
    # Compare against None instead of relying on truthiness: an exception
    # type that defines a falsy __bool__/__len__ would otherwise be
    # logged but silently dropped.
    if last_error is not None:
        raise last_error
# ============================== # ==============================
# Worker-side methods # Worker-side methods
# ============================== # ==============================
...@@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
to_return = (0, False) to_return = (0, False)
for i, c in enumerate(self._connectors): for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens( toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned # The first connector that has new matched tokens will be assigned
# to this request. # to this request.
if to_return[0] == 0 and toks > 0: if to_return[0] == 0 and toks > 0:
......
...@@ -14,6 +14,7 @@ from dataclasses import dataclass ...@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
import numpy as np
import torch import torch
import zmq import zmq
...@@ -60,6 +61,7 @@ except ImportError: ...@@ -60,6 +61,7 @@ except ImportError:
_NIXL_SUPPORTED_XPUS = { _NIXL_SUPPORTED_XPUS = {
"cuda": ("cuda", ), "cuda": ("cuda", ),
"tpu": ("cpu", ), "tpu": ("cpu", ),
"xpu": ("cpu", ),
} }
...@@ -160,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -160,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]: num_computed_tokens: int) -> tuple[Optional[int], bool]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)
...@@ -715,7 +717,7 @@ class NixlConnectorWorker: ...@@ -715,7 +717,7 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be) # are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger # Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor # Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = not (self.use_mla or self._use_pallas_v1 split_k_and_v = not (self.use_mla or self._use_pallas_v1
or self._use_flashinfer) or self._use_flashinfer)
...@@ -758,12 +760,21 @@ class NixlConnectorWorker: ...@@ -758,12 +760,21 @@ class NixlConnectorWorker:
assert tensor_size_bytes % self.num_blocks == 0 assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size self.slot_size_bytes = self.block_len // self.block_size
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer: if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0 assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2 self.slot_size_bytes /= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2
kv_block_len = self.get_backend_aware_kv_block_len()
# Register local/src descr for NIXL xfer. # Register local/src descr for NIXL xfer.
blocks_data = [] blocks_data = []
for base_addr in seen_base_addresses: for base_addr in seen_base_addresses:
...@@ -776,8 +787,18 @@ class NixlConnectorWorker: ...@@ -776,8 +787,18 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len block_offset = block_id * self.block_len
addr = base_addr + block_offset addr = base_addr + block_offset
# (addr, len, device id) # (addr, len, device id)
# TODO: does device_id matter to DRAM? blocks_data.append((addr, kv_block_len, self.tp_rank))
blocks_data.append((addr, self.block_len, self.tp_rank))
if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s", logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank) len(blocks_data), self.engine_id, self.tp_rank)
...@@ -787,7 +808,7 @@ class NixlConnectorWorker: ...@@ -787,7 +808,7 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs) "NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently diabled for # TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled. # models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4": if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig from transformers import Llama4TextConfig
...@@ -903,7 +924,7 @@ class NixlConnectorWorker: ...@@ -903,7 +924,7 @@ class NixlConnectorWorker:
remote_block_size = nixl_agent_meta.block_len // ( remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio) self.slot_size_bytes * tp_ratio)
if self._use_flashinfer: if self._use_flashinfer:
# Account for joint KV in FlashInfer. # With flashinfer, KV are sent in the same message.
remote_block_size //= 2 remote_block_size //= 2
if tp_ratio > 1: if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout. # Heterogeneous TP expects same kv_cache_layout.
...@@ -929,10 +950,10 @@ class NixlConnectorWorker: ...@@ -929,10 +950,10 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the # rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P). # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Only register the remote's descriptors if current rank pulls from it.
self.kv_caches_base_addr[ self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \ kv_block_len = self.get_backend_aware_kv_block_len()
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0 if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads. # Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr: for base_addr in nixl_agent_meta.kv_caches_base_addr:
...@@ -943,7 +964,16 @@ class NixlConnectorWorker: ...@@ -943,7 +964,16 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes. # self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
# (addr, len, device id) # (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank)) blocks_data.append((addr, kv_block_len, remote_tp_rank))
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
logger.debug( logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and " "Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank, "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
...@@ -1163,8 +1193,8 @@ class NixlConnectorWorker: ...@@ -1163,8 +1193,8 @@ class NixlConnectorWorker:
# workers will issue xfers to parts of the P worker remote kv caches. # workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids. # Get descs ids.
local_block_descs_ids: list[int] = [] local_block_descs_ids: np.ndarray
remote_block_descs_ids: list[int] = [] remote_block_descs_ids: np.ndarray
if not self.block_window_per_layer: if not self.block_window_per_layer:
# Default case: assume global attention # Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids( remote_block_descs_ids = self._get_block_descs_ids(
...@@ -1174,6 +1204,8 @@ class NixlConnectorWorker: ...@@ -1174,6 +1204,8 @@ class NixlConnectorWorker:
else: else:
# TODO(mgoin): remove this once we have hybrid memory allocator # TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4) # Optimization for models with local attention (Llama 4)
local_descs_list = []
remote_descs_list = []
for layer_idx, block_window in enumerate( for layer_idx, block_window in enumerate(
self.block_window_per_layer): self.block_window_per_layer):
# For each layer: # For each layer:
...@@ -1193,8 +1225,11 @@ class NixlConnectorWorker: ...@@ -1193,8 +1225,11 @@ class NixlConnectorWorker:
layer_remote_desc_ids = self._get_block_descs_ids( layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id, layer_remote_block_ids, layer_idx) dst_engine_id, layer_remote_block_ids, layer_idx)
local_block_descs_ids.extend(layer_local_desc_ids) local_descs_list.append(layer_local_desc_ids)
remote_block_descs_ids.extend(layer_remote_desc_ids) remote_descs_list.append(layer_remote_desc_ids)
local_block_descs_ids = np.concatenate(local_descs_list)
remote_block_descs_ids = np.concatenate(remote_descs_list)
assert len(local_block_descs_ids) == len(remote_block_descs_ids) assert len(local_block_descs_ids) == len(remote_block_descs_ids)
...@@ -1219,14 +1254,14 @@ class NixlConnectorWorker: ...@@ -1219,14 +1254,14 @@ class NixlConnectorWorker:
def _get_block_descs_ids(self, def _get_block_descs_ids(self,
engine_id: str, engine_id: str,
block_ids: list[int], block_ids: list[int],
layer_idx: Optional[int] = None) -> list[int]: layer_idx: Optional[int] = None) -> np.ndarray:
""" """
Get the descs ids for a set of block ids. Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer. If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions. Otherwise, we use all regions.
""" """
if layer_idx is None: if layer_idx is None:
region_ids = range(self.num_regions) region_ids = np.arange(self.num_regions)
else: else:
assert layer_idx < self.num_layers assert layer_idx < self.num_layers
if self.num_layers < self.num_regions: if self.num_layers < self.num_regions:
...@@ -1234,20 +1269,35 @@ class NixlConnectorWorker: ...@@ -1234,20 +1269,35 @@ class NixlConnectorWorker:
# the regions are organized as [K0, V0, K1, V1, ...] # the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i # and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions assert 2 * self.num_layers == self.num_regions
region_ids = range(2 * layer_idx, 2 * layer_idx + 2) region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
else: else:
# Otherwise, we assume we have MLA and select i-th layer # Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions assert self.num_layers == self.num_regions
region_ids = range(layer_idx, layer_idx + 1) region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id] num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block. # Compute the desc ids for each block.
descs_ids: list[int] = [] region_ids = region_ids[:, None]
for reg_id in region_ids: block_ids = np.array(block_ids)[None, :]
for block_id in block_ids: descs_ids = region_ids * num_blocks + block_ids
descs_ids.append(reg_id * num_blocks + block_id) return descs_ids.flatten()
return descs_ids
def get_backend_aware_kv_block_len(self):
    """
    Return the byte length used to index a single K or V element of a
    block (K and V have the same size).

    With FlashInfer, K and V share one region, so the result is half of
    `self.block_len` (floor division). For FA and other backends, K and
    V are in separate regions and the result is the whole
    `self.block_len`.
    """
    if not self._use_flashinfer:
        return self.block_len
    # Joint KV region: address only the K half or the V half.
    return self.block_len // 2
@contextlib.contextmanager @contextlib.contextmanager
......
...@@ -15,7 +15,7 @@ import msgpack ...@@ -15,7 +15,7 @@ import msgpack
import torch import torch
import zmq import zmq
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import hashlib import hashlib
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import safetensors import safetensors
import torch import torch
...@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
......
...@@ -13,7 +13,7 @@ import zmq ...@@ -13,7 +13,7 @@ import zmq
from safetensors.torch import load as safetensors_load from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port from vllm.utils import join_host_port, make_zmq_path, split_host_port
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment