Unverified Commit 098d8447 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[NIXL][1/N] Refactor `kernel_block_size` detection (#35752)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent a40ee486
......@@ -9,7 +9,7 @@ import textwrap
import time
import uuid
from collections import defaultdict
from typing import Any
from typing import Any, cast
from unittest.mock import MagicMock, patch
import msgspec
......@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
......@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
decode_connector.register_kv_caches(kv_caches)
......@@ -525,11 +533,13 @@ class TestNixlHandshake:
request_id = "req_id"
# Test worker role in decode server.
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker
......@@ -1479,18 +1489,22 @@ def test_register_kv_caches(
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]
# Create connector
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
)
# Get the mock instance
......@@ -1515,6 +1529,13 @@ def test_register_kv_caches(
num_layers = 32
block_size = 16
num_blocks = 8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config = make_kv_cache_config(
block_size=block_size, num_blocks=num_blocks
)
connector.connector_worker.kv_cache_config = worker_kv_cache_config
connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
......@@ -1568,11 +1589,17 @@ def test_register_kv_caches(
else:
# Create test kv cache tensors using proper backend shape
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
......@@ -1606,7 +1633,7 @@ def test_register_kv_caches(
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8
expected_blocks_count = kv_cache_config.num_blocks * 4
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
......@@ -1639,7 +1666,7 @@ def test_register_kv_caches(
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = 2
num_blocks = kv_cache_config.num_blocks
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
......@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat
},
)
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
local_vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
decode_worker = decode_connector.connector_worker
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
......
......@@ -38,7 +38,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
from vllm.v1.worker.utils import select_common_block_size
BLOCK_SIZE = 16
NUM_BLOCKS = 10
......@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec:
def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = select_common_block_size(128, attn_groups)
selected_size = select_common_block_size(128, [backend_a, backend_b])
assert selected_size == 128
def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = select_common_block_size(256, attn_groups)
selected_size = select_common_block_size(256, [backend_a, backend_b])
assert selected_size == 64
def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
with pytest.raises(ValueError):
select_common_block_size(48, attn_groups)
select_common_block_size(48, [backend_a, backend_b])
def test_update_states_new_request(model_runner, dist_init):
......
......@@ -358,15 +358,6 @@ class TpKVTopology:
# stride_order to retrieve physical position of block_size
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# In the default non-cross layers layout the block_size position
# is logical while in the cross layers case it is the physical
# position. This matches the shape of the actual kv cache tensors
# passed at register_kv_caches()/register_cross_layers_kv_cache()
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
assert block_size_position is not None
self._block_size_position = -(len(kv_cache_shape) - block_size_position)
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
......@@ -390,10 +381,6 @@ class TpKVTopology:
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._block_size_position
def tp_ratio(
self,
remote_tp_size: int,
......@@ -484,23 +471,46 @@ class TpKVTopology:
return self.get_target_remote_ranks(remote_tp_size)
def get_current_attn_backend(vllm_config: VllmConfig):
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
"""Get all distinct attention backends for the given layers.
Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.
Returns:
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers:
backend = next(iter(layers.values())).get_attn_backend()
else:
# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. "
"Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend
seen: dict[str, type[AttentionBackend]] = {}
for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())
# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend
backend = get_attn_backend(
return [
get_attn_backend(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
use_mla=vllm_config.model_config.use_mla,
)
return backend
]
def get_current_attn_backend(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0]
......@@ -13,7 +13,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
import msgspec
import numpy as np
......@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
get_current_attn_backend,
get_current_attn_backends,
kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive,
......@@ -61,6 +62,7 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -945,7 +947,8 @@ class NixlConnectorWorker:
# Config.
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# mypy will complain on re-assignment otherwise.
self.block_size: int = cast(int, vllm_config.cache_config.block_size)
if vllm_config.kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set for NixlConnector")
......@@ -993,7 +996,7 @@ class NixlConnectorWorker:
self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
self.num_blocks = kv_cache_config.num_blocks
self.enable_permute_local_kv = False
# KV Caches and nixl tracking data.
......@@ -1131,11 +1134,30 @@ class NixlConnectorWorker:
self.xfer_stats = NixlKVConnectorStats()
self._physical_blocks_per_logical_kv_block = 1
self._sync_block_size_with_kernel()
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
def _sync_block_size_with_kernel(self) -> None:
backends = get_current_attn_backends(self.vllm_config)
kernel_block_size = select_common_block_size(self.block_size, backends)
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter.",
self.block_size,
kernel_block_size,
)
assert self.block_size > kernel_block_size
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
self.num_blocks *= self._physical_blocks_per_logical_kv_block
def _nixl_handshake(
self,
host: str,
......@@ -1469,7 +1491,6 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
......@@ -1486,26 +1507,11 @@ class NixlConnectorWorker:
logger.debug(
"Registering layer %s with cache shape: %s", layer_name, cache.shape
)
kernel_block_size = cache.shape[self.kv_topo.block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert cache.shape[0] == self.num_blocks, (
"All kv cache tensors must have the same number of blocks"
......@@ -1514,9 +1520,6 @@ class NixlConnectorWorker:
self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks
)
self.slot_size_per_layer.append(
self.block_len_per_layer[-1] // self.block_size
)
if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP
......@@ -1534,7 +1537,6 @@ class NixlConnectorWorker:
"Different block lengths collected: %s", set(self.block_len_per_layer)
)
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
......@@ -1550,10 +1552,6 @@ class NixlConnectorWorker:
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
......
......@@ -258,7 +258,8 @@ class AttentionGroup:
def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
kv_manager_block_size: int,
backends: list[type[AttentionBackend]],
) -> int:
"""
Select a block size that is supported by all backends and is a factor of
......@@ -269,7 +270,7 @@ def select_common_block_size(
Args:
kv_manager_block_size: Block size of KV cache.
attn_groups: List of attention groups.
backends: List of attention backend classes.
Returns:
The selected block size.
......@@ -297,8 +298,6 @@ def select_common_block_size(
return False
return True
backends = [group.backend for group in attn_groups]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly.
if block_size_is_supported(backends, kv_manager_block_size):
......@@ -356,8 +355,9 @@ def prepare_kernel_block_sizes(
if isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual block splitting.
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
group_backends = [g.backend for g in attn_groups[kv_cache_gid]]
selected_kernel_size = select_common_block_size(
kv_manager_block_size, attn_groups[kv_cache_gid]
kv_manager_block_size, group_backends
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_spec, MambaSpec):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment