Commit 098d8447 (unverified), authored by Nicolò Lucchesi, committed by GitHub
Browse files

[NIXL][1/N] Refactor `kernel_block_size` detection (#35752)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent a40ee486
...@@ -9,7 +9,7 @@ import textwrap ...@@ -9,7 +9,7 @@ import textwrap
import time import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any, cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import msgspec import msgspec
...@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init): ...@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake # Prefill connector will register KV cache to populate proper handshake
# metadata. # metadata.
# TODO this must match with values used in kv cache config
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
prefill_connector = NixlConnector( prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
) )
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
) )
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = { kv_caches = {
"layer0": shared_tensor, "layer0": shared_tensor,
"layer1": unique_tensor, "layer1": unique_tensor,
...@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init): ...@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector. # Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector( decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) vllm_config, KVConnectorRole.WORKER, kv_cache_config
) )
decode_connector.register_kv_caches(kv_caches) decode_connector.register_kv_caches(kv_caches)
...@@ -525,11 +533,13 @@ class TestNixlHandshake: ...@@ -525,11 +533,13 @@ class TestNixlHandshake:
request_id = "req_id" request_id = "req_id"
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector( kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
) )
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker worker = connector.connector_worker
...@@ -1479,18 +1489,22 @@ def test_register_kv_caches( ...@@ -1479,18 +1489,22 @@ def test_register_kv_caches(
patch(f"{nixl_module}.threading.Event"), patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread, patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends,
): ):
# Ensure get_attn_backend returns the correct value due to # Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous # _cached_get_attn_backend returning the backend from previous
# test run if not mocking. # test run if not mocking.
mock_get_attn_backend.return_value = backend_cls mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]
# Create connector # Create connector
connector = NixlConnector( kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_config=kv_cache_config,
) )
# Get the mock instance # Get the mock instance
...@@ -1515,6 +1529,13 @@ def test_register_kv_caches( ...@@ -1515,6 +1529,13 @@ def test_register_kv_caches(
num_layers = 32 num_layers = 32
block_size = 16 block_size = 16
num_blocks = 8 num_blocks = 8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config = make_kv_cache_config(
block_size=block_size, num_blocks=num_blocks
)
connector.connector_worker.kv_cache_config = worker_kv_cache_config
connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
kv_cache_spec = AttentionSpec( kv_cache_spec = AttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=4, num_kv_heads=4,
...@@ -1568,11 +1589,17 @@ def test_register_kv_caches( ...@@ -1568,11 +1589,17 @@ def test_register_kv_caches(
else: else:
# Create test kv cache tensors using proper backend shape # Create test kv cache tensors using proper backend shape
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = backend_cls.get_kv_cache_shape( kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
) )
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = { kv_caches = {
"layer0": shared_tensor, "layer0": shared_tensor,
"layer1": unique_tensor, "layer1": unique_tensor,
...@@ -1606,7 +1633,7 @@ def test_register_kv_caches( ...@@ -1606,7 +1633,7 @@ def test_register_kv_caches(
unique_tensor[1].data_ptr(), unique_tensor[1].data_ptr(),
] ]
expected_num_entries = 4 expected_num_entries = 4
expected_blocks_count = 8 expected_blocks_count = kv_cache_config.num_blocks * 4
# Execute register_kv_caches # Execute register_kv_caches
connector.register_kv_caches(kv_caches) connector.register_kv_caches(kv_caches)
...@@ -1639,7 +1666,7 @@ def test_register_kv_caches( ...@@ -1639,7 +1666,7 @@ def test_register_kv_caches(
num_blocks = 8 num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks expected_block_len = expected_tensor_size // num_blocks
else: else:
num_blocks = 2 num_blocks = kv_cache_config.num_blocks
if is_blocks_first: if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2 expected_block_len = expected_tensor_size // num_blocks // 2
else: else:
...@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation( ...@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat "enforce_handshake_compat": enforce_handshake_compat
}, },
) )
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
decode_connector = NixlConnector( decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) local_vllm_config, KVConnectorRole.WORKER, kv_cache_config
) )
decode_worker = decode_connector.connector_worker decode_worker = decode_connector.connector_worker
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
head_size=kv_cache_spec.head_size,
) )
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
kv_caches = { kv_caches = {
"layer0": shared_tensor, "layer0": shared_tensor,
"layer1": unique_tensor, "layer1": unique_tensor,
......
...@@ -38,7 +38,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -38,7 +38,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size from vllm.v1.worker.utils import select_common_block_size
BLOCK_SIZE = 16 BLOCK_SIZE = 16
NUM_BLOCKS = 10 NUM_BLOCKS = 10
...@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec: ...@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec:
def test_select_common_block_size_prefers_manager_block_size(): def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)]) backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)]) backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = select_common_block_size(128, attn_groups) selected_size = select_common_block_size(128, [backend_a, backend_b])
assert selected_size == 128 assert selected_size == 128
def test_select_common_block_size_uses_largest_shared_int(): def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64]) backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32]) backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
selected_size = select_common_block_size(256, attn_groups) selected_size = select_common_block_size(256, [backend_a, backend_b])
assert selected_size == 64 assert selected_size == 64
def test_select_common_block_size_no_valid_option(): def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64]) backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)]) backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]
with pytest.raises(ValueError): with pytest.raises(ValueError):
select_common_block_size(48, attn_groups) select_common_block_size(48, [backend_a, backend_b])
def test_update_states_new_request(model_runner, dist_init): def test_update_states_new_request(model_runner, dist_init):
......
...@@ -358,15 +358,6 @@ class TpKVTopology: ...@@ -358,15 +358,6 @@ class TpKVTopology:
# stride_order to retrieve physical position of block_size # stride_order to retrieve physical position of block_size
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# In the default non-cross layers layout the block_size position
# is logical while in the cross layers case it is the physical
# position. This matches the shape of the actual kv cache tensors
# passed at register_kv_caches()/register_cross_layers_kv_cache()
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)
assert block_size_position is not None
self._block_size_position = -(len(kv_cache_shape) - block_size_position)
@property @property
def is_kv_layout_blocks_first(self) -> bool: def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first return self._is_kv_layout_blocks_first
...@@ -390,10 +381,6 @@ class TpKVTopology: ...@@ -390,10 +381,6 @@ class TpKVTopology:
def cross_layers_blocks(self) -> bool: def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._block_size_position
def tp_ratio( def tp_ratio(
self, self,
remote_tp_size: int, remote_tp_size: int,
...@@ -484,23 +471,46 @@ class TpKVTopology: ...@@ -484,23 +471,46 @@ class TpKVTopology:
return self.get_target_remote_ranks(remote_tp_size) return self.get_target_remote_ranks(remote_tp_size)
def get_current_attn_backend(vllm_config: VllmConfig): def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
"""Get all distinct attention backends for the given layers.
Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.
Returns:
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase) layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, None) layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers: if layers:
backend = next(iter(layers.values())).get_attn_backend() seen: dict[str, type[AttentionBackend]] = {}
else: for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())
# Fallback for tests, when static_forward_context is empty. # Fallback for tests, when static_forward_context is empty.
logger.debug( logger.debug(
"No layers found in the vLLM config. " "No layers found in the vLLM config. Falling back to default attention backend."
"Falling back to default attention backend."
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
backend = get_attn_backend( return [
get_attn_backend(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype, kv_cache_dtype=vllm_config.cache_config.cache_dtype,
use_mla=vllm_config.model_config.use_mla, use_mla=vllm_config.model_config.use_mla,
) )
return backend ]
def get_current_attn_backend(
    vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
    """Return a single attention backend for the given layers.

    Thin convenience wrapper around `get_current_attn_backends` that keeps
    only the first entry of the list it returns.

    Args:
        vllm_config: The current vLLM configuration.
        layer_names: Optional list of layer names to scope the lookup;
            forwarded unchanged to `get_current_attn_backends`.

    Returns:
        The first attention backend class in the returned list.
    """
    backends = get_current_attn_backends(vllm_config, layer_names)
    return backends[0]
...@@ -13,7 +13,7 @@ from collections import defaultdict ...@@ -13,7 +13,7 @@ from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
import msgspec import msgspec
import numpy as np import numpy as np
...@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( ...@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId, EngineId,
TpKVTopology, TpKVTopology,
get_current_attn_backend, get_current_attn_backend,
get_current_attn_backends,
kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive, kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive, kv_postprocess_layout_on_receive,
...@@ -61,6 +62,7 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout ...@@ -61,6 +62,7 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -945,7 +947,8 @@ class NixlConnectorWorker: ...@@ -945,7 +947,8 @@ class NixlConnectorWorker:
# Config. # Config.
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size # mypy will complain on re-assignment otherwise.
self.block_size: int = cast(int, vllm_config.cache_config.block_size)
if vllm_config.kv_transfer_config is None: if vllm_config.kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set for NixlConnector") raise ValueError("kv_transfer_config must be set for NixlConnector")
...@@ -993,7 +996,7 @@ class NixlConnectorWorker: ...@@ -993,7 +996,7 @@ class NixlConnectorWorker:
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
self.num_blocks = 0 self.num_blocks = kv_cache_config.num_blocks
self.enable_permute_local_kv = False self.enable_permute_local_kv = False
# KV Caches and nixl tracking data. # KV Caches and nixl tracking data.
...@@ -1131,11 +1134,30 @@ class NixlConnectorWorker: ...@@ -1131,11 +1134,30 @@ class NixlConnectorWorker:
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
self._physical_blocks_per_logical_kv_block = 1 self._physical_blocks_per_logical_kv_block = 1
self._sync_block_size_with_kernel()
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True "enforce_handshake_compat", True
) )
def _sync_block_size_with_kernel(self) -> None:
    """Align this worker's block size with the attention kernel block size.

    Looks up the current attention backends, asks
    `select_common_block_size` for a block size they all accept, and, if
    that differs from `self.block_size`, switches the worker over to it:
    the logical/physical ratio is stored in
    `_physical_blocks_per_logical_kv_block`, `block_size` and this
    engine's entry in `_block_size` are replaced by the kernel value, and
    `num_blocks` is scaled by the ratio.
    """
    attn_backends = get_current_attn_backends(self.vllm_config)
    physical_block_size = select_common_block_size(self.block_size, attn_backends)

    # Nothing to reconcile when logical and kernel block sizes agree.
    if self.block_size == physical_block_size:
        return

    logger.info_once(
        "User-specified logical block size (%s) does not match"
        " physical kernel block size (%s). Using the latter.",
        self.block_size,
        physical_block_size,
    )
    # A logical block may only be split into smaller kernel blocks.
    assert self.block_size > physical_block_size

    blocks_per_logical = self.block_size // physical_block_size
    self._physical_blocks_per_logical_kv_block = blocks_per_logical
    self.block_size = physical_block_size
    self._block_size[self.engine_id] = physical_block_size
    # Each logical block now counts as `blocks_per_logical` physical ones.
    self.num_blocks *= blocks_per_logical
def _nixl_handshake( def _nixl_handshake(
self, self,
host: str, host: str,
...@@ -1469,7 +1491,6 @@ class NixlConnectorWorker: ...@@ -1469,7 +1491,6 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used. # Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]() self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items(): for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = ( cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches] cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
...@@ -1486,26 +1507,11 @@ class NixlConnectorWorker: ...@@ -1486,26 +1507,11 @@ class NixlConnectorWorker:
logger.debug( logger.debug(
"Registering layer %s with cache shape: %s", layer_name, cache.shape "Registering layer %s with cache shape: %s", layer_name, cache.shape
) )
kernel_block_size = cache.shape[self.kv_topo.block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
seen_base_addresses.append(base_addr) seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size() curr_tensor_size_bytes = cache.numel() * cache.element_size()
if tensor_size_bytes is None: if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert cache.shape[0] == self.num_blocks, ( assert cache.shape[0] == self.num_blocks, (
"All kv cache tensors must have the same number of blocks" "All kv cache tensors must have the same number of blocks"
...@@ -1514,9 +1520,6 @@ class NixlConnectorWorker: ...@@ -1514,9 +1520,6 @@ class NixlConnectorWorker:
self.block_len_per_layer.append( self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks curr_tensor_size_bytes // self.num_blocks
) )
self.slot_size_per_layer.append(
self.block_len_per_layer[-1] // self.block_size
)
if not self.use_mla: if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP # Different kv cache shape is not supported by HeteroTP
...@@ -1534,7 +1537,6 @@ class NixlConnectorWorker: ...@@ -1534,7 +1537,6 @@ class NixlConnectorWorker:
"Different block lengths collected: %s", set(self.block_len_per_layer) "Different block lengths collected: %s", set(self.block_len_per_layer)
) )
assert len(self.block_len_per_layer) == len(seen_base_addresses) assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
...@@ -1550,10 +1552,6 @@ class NixlConnectorWorker: ...@@ -1550,10 +1552,6 @@ class NixlConnectorWorker:
self.dst_num_blocks[self.engine_id] = self.num_blocks self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2
# NOTE (NickLucche) When FlashInfer is used, memory is registered # NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in # with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to # registerMem allowing faster descs queries. In order to be able to
......
...@@ -258,7 +258,8 @@ class AttentionGroup: ...@@ -258,7 +258,8 @@ class AttentionGroup:
def select_common_block_size( def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup] kv_manager_block_size: int,
backends: list[type[AttentionBackend]],
) -> int: ) -> int:
""" """
Select a block size that is supported by all backends and is a factor of Select a block size that is supported by all backends and is a factor of
...@@ -269,7 +270,7 @@ def select_common_block_size( ...@@ -269,7 +270,7 @@ def select_common_block_size(
Args: Args:
kv_manager_block_size: Block size of KV cache. kv_manager_block_size: Block size of KV cache.
attn_groups: List of attention groups. backends: List of attention backend classes.
Returns: Returns:
The selected block size. The selected block size.
...@@ -297,8 +298,6 @@ def select_common_block_size( ...@@ -297,8 +298,6 @@ def select_common_block_size(
return False return False
return True return True
backends = [group.backend for group in attn_groups]
# Case 1: if the block_size of kv cache manager is supported by all backends, # Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly. # return it directly.
if block_size_is_supported(backends, kv_manager_block_size): if block_size_is_supported(backends, kv_manager_block_size):
...@@ -356,8 +355,9 @@ def prepare_kernel_block_sizes( ...@@ -356,8 +355,9 @@ def prepare_kernel_block_sizes(
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual block splitting. # This is an attention backend that supports virtual block splitting.
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
group_backends = [g.backend for g in attn_groups[kv_cache_gid]]
selected_kernel_size = select_common_block_size( selected_kernel_size = select_common_block_size(
kv_manager_block_size, attn_groups[kv_cache_gid] kv_manager_block_size, group_backends
) )
kernel_block_sizes.append(selected_kernel_size) kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment