Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
......@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
seq_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
......@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
Args:
seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
bool: Whether it's possible to swap out current sequence group.
......@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
swapping out the given sequence_group with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap out.
seq_group (SequenceGroup): The sequence group to swap out.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from
......@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
on to the 'device'.
Args:
sequence_group (SequenceGroup): The sequence group to swap in/out.
seq_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
......
......@@ -76,7 +76,7 @@ class LRUEvictor(Evictor):
that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily
highest num_hashed_tokens value, then one will be chosen arbitrarily
"""
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority
......
......@@ -657,7 +657,7 @@ class Scheduler:
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
that are currently running
Returns:
SchedulerRunningOutputs.
......@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
......
......@@ -152,8 +152,13 @@ class CuMemAllocator:
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
......@@ -162,7 +167,7 @@ class CuMemAllocator:
allocation_handle, self.current_tag)
return
def python_free_callback(self, ptr: int) -> HandleType:
def _python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
......@@ -212,9 +217,9 @@ class CuMemAllocator:
def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
......
......@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__)
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
CUSTOM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: MiB // 2, # 512 KB
8: MiB // 4, # 256 KB
},
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: 2 * MiB, # 2 MB
}
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
"10.0": {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}
def producer(batch_src: Sequence[int],
producer_queue,
......
......@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize()
module.quant_method.init_prepare_finalize(module)
def dispatch(
self, hidden_states: torch.Tensor,
......
......@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
......@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
......@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
......@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
......@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......
......@@ -11,8 +11,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -114,7 +114,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
......@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase
......@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
logger = init_logger(__name__)
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
......@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim)
try:
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
......@@ -7,8 +7,13 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase):
......@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
......
......@@ -244,7 +244,7 @@ class EplbState:
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_window_size
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
......@@ -253,7 +253,7 @@ class EplbState:
)
# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
......@@ -409,12 +409,14 @@ class EplbState:
self.expert_rearrangement_step = 0
self.rearrange(model)
def rearrange(self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None) -> None:
def rearrange(
self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
"""
Rearrange the experts according to the current load.
"""
......@@ -548,6 +550,7 @@ class EplbState:
" (profile) " if is_profile else " ",
time_end - time_start,
)
return None
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
......@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id
\ No newline at end of file
return next_node_id
......@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):
......
......@@ -2,7 +2,7 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling.
Currently the main use case is for disaggregated prefilling.
## Abstractions
......@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
......
......@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
......@@ -34,6 +36,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
......@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
......@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
Args:
kv_caches: dictionary of layer names, kv cache
"""
return
......@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
......@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]:
for c in self._connectors:
yield from c.take_events()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
......
......@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
......@@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affects the strides. For MLA instead, we make require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
"use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
self.use_host_buffer)
caches_data = []
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses = []
xfer_buffers = (self.host_xfer_buffers
if self.use_host_buffer else kv_caches)
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
......@@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
or self._use_pallas_v1 or self._use_flashinfer \
else cache_or_caches
split_k_and_v = not (self.use_mla or self._use_pallas_v1
or self._use_flashinfer)
tensor_size_bytes = None
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
]
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
# NOTE: use tp_rank for device_id since multi-node TP
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
if base_addr in seen_base_addresses:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
......@@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger.debug("Done registering descs")
self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for base_addr in seen_base_addresses:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
......@@ -836,6 +787,26 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently diabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
......
......@@ -30,27 +30,19 @@ logger = init_logger(__name__)
class ReqMeta:
# Request Id
request_id: str
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Request block ids
block_ids: torch.Tensor
# Request num tokens
num_tokens: int
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
block_size: int) -> "ReqMeta":
valid_num_tokens = len(token_ids)
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
request_id=request_id,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
block_ids=block_ids_tensor,
num_tokens=len(token_ids),
)
......@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1):
return
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
layer: torch.Tensor,
kv_cache: torch.Tensor,
block_ids: torch.Tensor,
request_id: str,
) -> None:
"""Inject the KV cache into the layer.
"""
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
request_id (str): request id for log
layer (torch.Tensor): The attention layer KV tensor to update.
kv_cache (torch.Tensor): The KV cache tensor to inject.
block_ids (torch.Tensor): Indices of the blocks to update.
request_id (str): Request identifier used for logging.
Returns:
None. The function modifies `layer` in-place.
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values())
or layer.shape[1] == 2): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block:
layer[block_ids, ...] = kv_cache
else:
dst_kv_cache_layer[slot_mapping[:num_token],
...] = src_kv_cache
layer[block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
elif layer.shape[0] == 2: # FlashAttention
num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache
else:
dst_kv_cache_layer[:, slot_mapping[:num_token],
...] = src_kv_cache
layer[:, block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
# Get the metadata
metadata: KVConnectorMetadata = \
......@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if kv_cache is None:
continue
kv_cache_layer = kv_cache[ \
forward_context.virtual_engine]
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
logger.warning("🚧kv_cache is None, %s", request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
inject_kv_into_layer(layer, kv_cache, request.block_ids,
request.request_id)
tensor = self.p2p_nccl_engine.recv_store.pop(request.request_id + "#" + layer_name, None)
if tensor is not None:
del tensor
......@@ -248,16 +233,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
block_ids: torch.Tensor,
) -> torch.Tensor:
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
return None
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
self.p2p_nccl_engine.send_tensor(
request_id + "#" + layer_name, kv_layer, remote_address,
request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
def wait_for_save(self):
if self.is_producer:
......
......@@ -62,8 +62,6 @@ class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
slot_mapping: torch.Tensor
is_mla: bool
class P2pNcclEngine:
......@@ -202,8 +200,6 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
......@@ -213,9 +209,7 @@ class P2pNcclEngine:
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor,
slot_mapping=slot_mapping,
is_mla=is_mla)
tensor=tensor)
if self.send_type == "PUT":
return self.send_sync(item)
......@@ -433,9 +427,7 @@ class P2pNcclEngine:
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
with self.send_stream:
tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
tensor = item.tensor
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
......@@ -548,21 +540,3 @@ class P2pNcclEngine:
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
is_mla: bool,
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if is_mla:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
......@@ -99,8 +99,9 @@ class TensorMemoryPool:
addr=self.base_address)
self.free_lists[self.max_block_size][
initial_block.addr] = initial_block
logger.debug("TensorMemoryPool, base_address:", self.base_address,
self.base_address % self.max_block_size)
logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
self.base_address, self.max_block_size)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment