Unverified Commit 53415653 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 17373dcd
...@@ -14,6 +14,7 @@ from unittest.mock import patch ...@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest import pytest
import ray import ray
import torch
from vllm import LLM from vllm import LLM
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
...@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( ...@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker) NixlConnectorWorker)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from .utils import create_request, create_scheduler, create_vllm_config from .utils import create_request, create_scheduler, create_vllm_config
...@@ -98,7 +100,6 @@ class FakeNixlWrapper: ...@@ -98,7 +100,6 @@ class FakeNixlWrapper:
def set_cycles_before_xfer_done(self, cycles: int): def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done.""" """Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
@contextlib.contextmanager @contextlib.contextmanager
...@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): ...@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params) sampling_params)
# Request-0 times out and is cleared! # Request-0 times out and is cleared!
assert '0' not in req_to_blocks assert '0' not in req_to_blocks
def test_register_kv_caches(dist_init):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
block layout info
"""
vllm_config = create_vllm_config()
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
block_size=16,
num_kv_heads=4,
head_size=64)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size(
) * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
]
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
assert size == expected_tensor_size, \
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
f"got {size}"
assert base_addr == expected_base_addrs[i], \
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
f"got {base_addr}"
# Verify get_xfer_descs was called with blocks_data
assert mock_wrapper_instance.get_xfer_descs.called
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, \
f"Expected {expected_blocks_count} blocks, " \
f"got {len(blocks_data)}"
expected_block_len = expected_tensor_size // 2
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"
...@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC): ...@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL). KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches: Args:
dictionary of layer names, kv cache kv_caches: dictionary of layer names, kv cache
""" """
return return
......
...@@ -686,9 +686,6 @@ class NixlConnectorWorker: ...@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl.""" """Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer: if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches) self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), ( assert len(self.host_xfer_buffers) == len(kv_caches), (
...@@ -701,66 +698,16 @@ class NixlConnectorWorker: ...@@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when " "host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}") f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affects the strides. For MLA instead, we make require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info( logger.info(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, " "use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, self.use_host_buffer)
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = [] caches_data = []
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses = []
xfer_buffers = (self.host_xfer_buffers
if self.use_host_buffer else kv_caches)
# Note(tms): I modified this from the original region setup code. # Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can # K and V are now in different regions. Advantage is that we can
...@@ -770,42 +717,35 @@ class NixlConnectorWorker: ...@@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor # Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in xfer_buffers.values(): split_k_and_v = not (self.use_mla or self._use_pallas_v1
# Normalize to always be a list of caches or self._use_flashinfer)
cache_list = [cache_or_caches] if use_mla \ tensor_size_bytes = None
or self._use_pallas_v1 or self._use_flashinfer \ for layer_name, cache_or_caches in xfer_buffers.items():
else cache_or_caches cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
]
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len if base_addr in seen_base_addresses:
# NOTE: use tp_rank for device_id since multi-node TP continue
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, "")) seen_base_addresses.append(base_addr)
kv_caches_base_addr.append(base_addr) curr_tensor_size_bytes = cache.numel() * cache.element_size()
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys()) self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type) self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data) logger.debug("Registering descs: %s", caches_data)
...@@ -813,9 +753,20 @@ class NixlConnectorWorker: ...@@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger.debug("Done registering descs") logger.debug("Done registering descs")
self._registered_descs.append(descs) self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# Register local/src descr for NIXL xfer. # Register local/src descr for NIXL xfer.
blocks_data = [] blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]: for base_addr in seen_base_addresses:
# NOTE With heter-TP, more blocks are prepared than what are # NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to # could create fewer, but then _get_block_descs_ids needs to
...@@ -836,6 +787,26 @@ class NixlConnectorWorker: ...@@ -836,6 +787,26 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs) "NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently diabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections. # After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata( metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment