"lib/runtime/vscode:/vscode.git/clone" did not exist on "e232bec07ceea0b18c94dacc7c052a6ee1daaaf3"
Unverified Commit 5b3ba94a authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Core][KVConnector] Support HMA+NixlConnector (#35758)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 90f3c01f
......@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
)
dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
......@@ -26,6 +27,14 @@ else
configs=("${tp_configs[@]}")
fi
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for i in "${!configs[@]}"; do
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
done
fi
run_tests() {
local label=$1
local extra_args=$2
......
......@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS="False"
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
fi
while [[ $# -gt 0 ]]; do
case $1 in
--kv_buffer_device)
......@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
if [[ -n "$ENABLE_HMA_VAR" ]]; then
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
......@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS=${VLLM_SERVE_EXTRA_ARGS:-}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
......@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
FULL_CMD="$BASE_CMD"
eval "$FULL_CMD &"
# Store host and port for proxy configuration
......@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
# DP-EP attention mode
if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
......@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
}
SIMPLE_PROMPT = (
......
......@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config
from .utils import (
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)
@pytest.fixture(scope="module", autouse=True)
......@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
for block_id, block in zip(
req_meta.local_block_ids,
req_meta.local_block_ids[0],
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
],
......@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
......@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
delay, kv_connector_metadata = (
scheduler.get_kv_connector().request_finished_all_groups(
request, ([0, 1, 2],)
)
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata.
......@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
self,
*args,
hand_shake_latency: float = 1.8,
kv_cache_layout="HND",
kv_cache_config=None,
**kwargs,
):
super().__init__(*args, **kwargs)
if kv_cache_config is None:
kv_cache_config = make_kv_cache_config(block_size=16)
super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
......@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id = "req_id"
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers -= 1
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
local_block_ids=([num_xfers + 1, num_xfers + 2, num_xfers + 3],),
kv_transfer_params={
"remote_block_ids": [
num_xfers + 4,
num_xfers + 5,
num_xfers + 6,
],
"remote_block_ids": (
[
num_xfers + 4,
num_xfers + 5,
num_xfers + 6,
],
),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id="id",
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id",
"remote_host": "localhost",
......@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0 = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p1 = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
......@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
......@@ -827,9 +859,9 @@ class TestNixlHandshake:
for i in range(total_reqs):
metadata.add_new_req_to_recv(
request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost",
......@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
connector.engine_id,
......@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend.return_value = backend_cls
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
): # noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
# Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
......@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
)
worker = NixlConnectorWorker(
vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with (
......@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler = create_scheduler(vllm_config)
# KVConnector Worker in P
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
("transfer_exception", {"fail_transfer_exception": True}, True),
],
)
@pytest.mark.parametrize("enable_hma", [False, True])
def test_transfer_failure_logging(
default_vllm_config,
dist_init,
failure_type,
wrapper_config,
needs_get_finished,
enable_hma,
):
"""Test that transfer failures are logged with structured context.
......@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.0
vllm_config,
connector.engine_id,
hand_shake_latency=0.0,
kv_cache_config=connector._kv_cache_config,
)
# Configure FailingNixlWrapper to fail in the specified way
......@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
remote_blocks = [20, 21, 22]
local_blocks: tuple[()] | tuple[list[int], ...]
if enable_hma:
# HMA enabled: multiple groups (FA + SW)
local_blocks = (
() if failure_type == "notification_failed" else ([10, 11, 12], [13, 14])
)
remote_blocks = [[20, 21, 22], [23, 24]]
else:
# HMA disabled: single group
local_blocks = () if failure_type == "notification_failed" else ([10, 11, 12],)
remote_blocks = [[20, 21, 22]]
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
......@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.1
)
......@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[7, 8, 9],
local_block_ids=([7, 8, 9],),
kv_transfer_params={
"remote_block_ids": [10, 11, 12],
"remote_block_ids": ([10, 11, 12],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat
},
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
......@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model="facebook/opt-125m",
block_size=16,
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from unittest.mock import patch
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager,
SlidingWindowManager,
)
from .utils import (
create_vllm_config,
make_kv_cache_config,
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"hma_enabled,expected_sw_sizes",
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(True, [0, 128 + 1]),
# HMA disabled: only FullAttentionSpec (0)
(False, [0]),
],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
mock_platform.device_type = "cpu"
block_size = 16
vllm_config = create_vllm_config(block_size=block_size)
# SW 2048 tokens=>128 blocks
kv_cache_config = make_kv_cache_config(
block_size=block_size, hma_enabled=hma_enabled, sw_size=2048
)
scheduler = NixlConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
# in number of blocks
assert scheduler.blocks_per_sw == expected_sw_sizes, (
f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}"
)
@pytest.mark.cpu_test
def test_logical_to_kernel_block_ids_with_hma():
"""Test _logical_to_kernel_block_ids expands blocks when HMA is enabled.
When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker,
)
# Create a mock worker with just the required attributes
# (use __new__ to skip __init__)
worker = object.__new__(NixlConnectorWorker)
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2
# Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]]
kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids)
expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]
assert kernel_block_ids == expected_kernel_block_ids, (
f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}"
)
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
when sequence exceeds the sliding window.
"""
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
)
block_size = 16
llm_kwargs = {
"model": model_name,
"enforce_eager": True,
"gpu_memory_utilization": 0.5,
"kv_transfer_config": kv_transfer_config,
"max_model_len": 2048,
# NOTE: Make sure HMA is enabled
"disable_hybrid_kv_cache_manager": False,
"max_num_batched_tokens": 1024,
"enable_prefix_caching": False,
"block_size": block_size,
}
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
def run_hma_test(llm: LLM):
remote_prefill_opts = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
# Simulate sidecar request
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=1,
extra_args={"kv_transfer_params": remote_prefill_opts},
)
scheduler = llm.llm_engine.engine_core.engine_core.scheduler
kv_managers = scheduler.kv_cache_manager.coordinator.single_type_managers
# HMA enabled with FA + SWA groups
assert len(kv_managers) > 2
for kv_manager in kv_managers:
assert isinstance(kv_manager, (SlidingWindowManager, FullAttentionManager))
req_to_blocks = kv_managers[0].req_to_blocks
assert len(req_to_blocks) == 0
# Process some request with length exceeding the sliding window
outputs = llm.generate(["hi" * 1401], sampling_params)
kv_params = outputs[0].kv_transfer_params
# +1 to account for overlapping window across blocks.
expected_num_remote_blocks = sw_size // block_size + 1
remote_block_ids = kv_params["remote_block_ids"]
assert (
len(remote_block_ids[0])
== expected_num_remote_blocks
< len(remote_block_ids[-1])
)
for group_block_ids in remote_block_ids[:-1]:
assert len(group_block_ids) == expected_num_remote_blocks
def run_test_and_cleanup():
llm = LLM(**llm_kwargs)
try:
run_hma_test(llm)
finally:
llm.llm_engine.engine_core.shutdown()
run_test_and_cleanup()
@pytest.mark.cpu_test
def test_nixl_metadata_hma_block_ids_structure():
"""
Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata,
)
metadata = NixlConnectorMetadata()
# Add request with block IDs for 2 groups (FA + SW)
fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7] # 8 blocks for FA
sw_blocks = [8, 9, 10, 11] # 4 blocks for SW (clipped)
metadata.add_new_req_to_recv(
request_id="test-req-hma",
local_block_ids=(fa_blocks, sw_blocks),
kv_transfer_params={
"remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21]),
"remote_engine_id": "remote-engine",
"remote_request_id": "prefill-test-req-hma",
"remote_host": "localhost",
"remote_port": 1234,
"tp_size": 1,
},
)
assert "test-req-hma" in metadata.reqs_to_recv
req_meta = metadata.reqs_to_recv["test-req-hma"]
# Verify local block IDs structure
assert len(req_meta.local_block_ids) == 2
assert list(req_meta.local_block_ids[0]) == fa_blocks
assert list(req_meta.local_block_ids[1]) == sw_blocks
# Verify remote block IDs structure
assert req_meta.remote is not None
assert len(req_meta.remote.block_ids) == 2
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21]
......@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks = sum(len(g) for g in kv_transfer_params["remote_block_ids"])
assert num_remote_blocks == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
......
......@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
......@@ -142,24 +143,26 @@ def create_vllm_config(
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
kv_cache_config: KVCacheConfig | None = None,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
if kv_cache_config is None:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
......@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector(
"MockKVConnector", __name__, MockKVConnector.__name__
)
def make_kv_cache_config(
block_size: int,
hma_enabled: bool = False,
sw_size: int = 128,
num_blocks: int = 100,
) -> KVCacheConfig:
kv_cache_groups = [
KVCacheGroupSpec(
["layer0", "layer2"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=16,
dtype=torch.float16,
),
)
]
if hma_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["layer1", "layer3"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=4,
head_size=16,
dtype=torch.float16,
sliding_window=sw_size,
),
)
)
return KVCacheConfig(
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)
......@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
EngineId = str
# block ids as returned by the hybrid KV cache manager. list[list[int]] are allow
# mutability and are for connector internal use only.
BlockIds = tuple[list[int], ...] | list[list[int]]
def get_kv_connector_cache_layout():
......
......@@ -3,7 +3,6 @@
import contextlib
import copy
import logging
import math
import os
import queue
import sys
......@@ -24,6 +23,7 @@ import zmq
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds,
EngineId,
TpKVTopology,
get_current_attn_backend,
......@@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
......@@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
......@@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash(
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
is_hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
factors = {
# Version compatibility
......@@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash(
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
"cross_layers_blocks": cross_layers_blocks,
"is_hma_enabled": is_hma_enabled,
}
compat_hash = hash_factors(factors)
......@@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash(
@dataclass
class RemoteMeta:
block_ids: list[int]
block_ids: BlockIds
host: str
port: int
engine_id: str
......@@ -247,9 +252,9 @@ class RemoteMeta:
@dataclass
class ReqMeta:
local_block_ids: list[int]
local_block_ids: BlockIds
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
local_physical_block_ids: BlockIds
tp_size: int
remote: RemoteMeta | None = None
......@@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def _add_new_req(
self,
local_block_ids: list[int],
local_block_ids: BlockIds,
kv_transfer_params: dict[str, Any],
) -> ReqMeta:
return ReqMeta(
......@@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def add_new_req_to_save(
self,
request_id: ReqId,
local_block_ids: list[int],
local_block_ids: BlockIds,
kv_transfer_params: dict[str, Any],
):
self.reqs_to_save[request_id] = self._add_new_req(
......@@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def add_new_req_to_recv(
self,
request_id: ReqId,
local_block_ids: list[int],
local_block_ids: BlockIds,
kv_transfer_params: dict[str, Any],
):
req = self._add_new_req(local_block_ids, kv_transfer_params)
......@@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_to_recv[request_id] = req
class NixlConnector(KVConnectorBase_V1):
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
......@@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
for group in kv_cache_config.kv_cache_groups:
if isinstance(group.kv_cache_spec, MambaSpec):
raise ValueError("NixlConnector does not support Mamba models.")
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: NixlConnectorScheduler | None = (
NixlConnectorScheduler(vllm_config, self.engine_id)
NixlConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
)
self.connector_worker: NixlConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id)
self.connector_worker = NixlConnectorWorker(
vllm_config, self.engine_id, kv_cache_config
)
############################################################
# Class Methods
......@@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
def request_finished_all_groups(
self,
request: "Request",
block_ids: list[int],
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
......@@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1):
class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
def __init__(
self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig"
):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id
self.kv_cache_config = kv_cache_config
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
......@@ -534,8 +547,18 @@ class NixlConnectorScheduler:
self.use_host_buffer = (
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
)
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
# Also handle unlikely SW-only model case instead of checking num_groups>1.
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
logger.info("Initializing NIXL Scheduler %s", engine_id)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
logger.info("Hybrid Memory Allocator is enabled with NIXL")
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
......@@ -545,7 +568,7 @@ class NixlConnectorScheduler:
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
......@@ -554,12 +577,54 @@ class NixlConnectorScheduler:
# remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set()
# Gather Sliding Window sizes for each kv cache group (if any) in number of
# blocks per KV cache group. This is used to clip the local attention window.
sw_sizes_tokens: list[tuple[int, int]] = [
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
else (0, self.block_size)
for g in kv_cache_config.kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
# account for boundary overlap eg window isn't fully aligned with blocks.
self.blocks_per_sw = [
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
for n_tokens, block_size in sw_sizes_tokens
]
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None
def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
"""
Clip the number of blocks to the sliding window size for each kv cache group
that employs SWA.
This is necessary because the KV Cache manager initially allocates blocks for
the entire sequence length, and successively cleans up blocks that are outside
the window prior to the `request_finished_all_groups` hook.
"""
if len(block_ids) == 0 or not self._is_hma_required:
# No blocks to clip eg Full prefix cache hit or not a hybrid model.
return block_ids
# NOTE (NickLucche) This logic is currently handled at the connector level
# because offloading connectors might want to receive the whole sequence even
# for SWA groups. We will abstract this logic once the interface is more stable
assert len(block_ids) == len(self.blocks_per_sw), (
"Number of KV cache groups must match"
)
# For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
return tuple(
[
blocks[-self.blocks_per_sw[i] :]
if self.blocks_per_sw[i] > 0
else blocks
for i, blocks in enumerate(block_ids)
]
)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
......@@ -707,12 +772,18 @@ class NixlConnectorScheduler:
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids()
unhashed_local_block_ids: BlockIds = (
blocks.get_unhashed_block_ids_all_groups()
if num_external_tokens > 0
else []
else ()
)
# Get unhashed blocks to pull from remote.
local_block_ids = self.get_sw_clipped_blocks(
unhashed_local_block_ids
)
# Get unhashed blocks to pull from remote. Mind that a full prefix
# cache hit is indicated with an empty list.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
......@@ -753,9 +824,10 @@ class NixlConnectorScheduler:
req = req_to_save
assert req.kv_transfer_params is not None
clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups)
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=new_block_id_groups[0],
local_block_ids=clipped_block_id_groups,
kv_transfer_params=req.kv_transfer_params,
)
assert scheduler_output.num_scheduled_tokens is not None
......@@ -786,7 +858,7 @@ class NixlConnectorScheduler:
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: BlockIds,
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
......@@ -828,7 +900,7 @@ class NixlConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = len(block_ids) > 0
delay_free_blocks = any(len(group) > 0 for group in block_ids)
if delay_free_blocks:
# Prefill request on remote. It will be read from D upon completion
......@@ -841,6 +913,11 @@ class NixlConnectorScheduler:
self._reqs_need_send[request.request_id] = (
time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
)
# NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
# trimming down after allocating for the whole sequence length. Empty
# blocks are always at the start of the list.
# Here we "unpad" blocks to send the actual remote blocks to be read.
block_ids = self.get_sw_clipped_blocks(block_ids)
return delay_free_blocks, dict(
do_remote_prefill=True,
......@@ -857,7 +934,9 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
def __init__(
self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig"
):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
......@@ -875,6 +954,14 @@ class NixlConnectorWorker:
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
self.kv_cache_config = kv_cache_config
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
......@@ -1017,10 +1104,6 @@ class NixlConnectorWorker:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self.block_window_per_layer: list[int | None] = []
self.use_mla = self.model_config.use_mla
# Get the attention backend from the first layer
......@@ -1030,8 +1113,8 @@ class NixlConnectorWorker:
self.backend_name = self.attn_backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
logger.info("Detected attention backend %s", self.backend_name)
logger.info("Detected kv cache layout %s", self.kv_cache_layout)
# lazy initialized in register_kv_caches
self.compat_hash: str | None = None
......@@ -1238,9 +1321,15 @@ class NixlConnectorWorker:
"remote_request_id": meta.remote.request_id,
"remote_host": meta.remote.host,
"remote_port": meta.remote.port,
"num_local_blocks": len(meta.local_block_ids),
"num_remote_blocks": len(meta.remote.block_ids),
"local_block_ids_sample": meta.local_block_ids[:10],
"num_local_blocks": sum(
len(group) for group in meta.local_block_ids
),
"num_remote_blocks": sum(
len(group) for group in meta.remote.block_ids
),
"local_block_ids_sample": meta.local_block_ids[0][:10]
if meta.local_block_ids
else [],
}
)
......@@ -1301,8 +1390,10 @@ class NixlConnectorWorker:
error=e,
meta=meta,
)
if req_meta := self._recving_metadata.get(req_id):
self._invalid_block_ids.update(req_meta.local_block_ids)
if (
req_meta := self._recving_metadata.get(req_id)
) and not self._is_hma_required:
self._invalid_block_ids.update(req_meta.local_block_ids[0])
self._failed_recv_reqs.add(req_id)
fut.add_done_callback(request_ready)
......@@ -1370,6 +1461,10 @@ class NixlConnectorWorker:
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
# NOTE (NickLucche) HMA employs memory pooling to share tensors
# across groups. This results in skipping all tensors but the ones
# pointed to by group0. Also, generally we will have more blocks
# per tensor but fewer regions.
continue
logger.debug(
......@@ -1457,28 +1552,6 @@ class NixlConnectorWorker:
self.register_local_xfer_handler(self.block_size)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
llama4_config = self.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug(
"Llama 4 block window per layer mapping: %s",
self.block_window_per_layer,
)
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
......@@ -1767,6 +1840,11 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
if self._is_hma_required:
assert block_size_ratio == 1, (
"HMA does not support different remote block size yet"
)
kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
......@@ -1781,6 +1859,9 @@ class NixlConnectorWorker:
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
)
assert not self._is_hma_required, (
"HMA does not support block size post processing"
)
self.enable_permute_local_kv = True
else:
raise RuntimeError(
......@@ -1836,13 +1917,15 @@ class NixlConnectorWorker:
assert self.copy_blocks is not None
local_block_ids = meta.local_physical_block_ids
self.copy_blocks(
self.host_xfer_buffers,
self.device_kv_caches,
local_block_ids,
local_block_ids,
"h2d",
)
# TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups
for group_block_ids in local_block_ids:
self.copy_blocks(
self.host_xfer_buffers,
self.device_kv_caches,
group_block_ids,
group_block_ids,
"h2d",
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"synced recved kv of request[%s] to device kv buffer,"
......@@ -1868,13 +1951,14 @@ class NixlConnectorWorker:
",".join(map(str, meta.local_physical_block_ids)),
)
# blocking
self.copy_blocks(
self.device_kv_caches,
self.host_xfer_buffers,
meta.local_physical_block_ids,
meta.local_physical_block_ids,
"d2h",
)
for group_block_ids in meta.local_physical_block_ids:
self.copy_blocks(
self.device_kv_caches,
self.host_xfer_buffers,
group_block_ids,
group_block_ids,
"d2h",
)
def post_process_device_kv_on_receive(
self,
......@@ -1973,8 +2057,9 @@ class NixlConnectorWorker:
if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv
):
assert not self._is_hma_required
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_physical_block_ids
meta.local_physical_block_ids[0]
)
for (
block_size_ratio,
......@@ -2106,8 +2191,9 @@ class NixlConnectorWorker:
handle: The transfer handle.
"""
# Use .get() here as the metadata cleanup is handled by get_finished()
if meta := self._recving_metadata.get(req_id):
self._invalid_block_ids.update(meta.local_block_ids)
# TODO (NickLucche) handle failed transfer for HMA.
if (meta := self._recving_metadata.get(req_id)) and not self._is_hma_required:
self._invalid_block_ids.update(meta.local_block_ids[0])
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
......@@ -2230,8 +2316,8 @@ class NixlConnectorWorker:
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
local_block_ids: BlockIds,
remote_block_ids: BlockIds,
dst_engine_id: str,
request_id: str,
remote_request_id: str,
......@@ -2246,22 +2332,30 @@ class NixlConnectorWorker:
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
np.asarray(local_block_ids), block_size_ratio
)
if len(local_block_ids) > len(remote_block_ids):
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
assert not self._is_hma_required
local_block_ids0 = local_block_ids[0] if local_block_ids else []
remote_block_ids0 = remote_block_ids[0]
local_block_ids_mapped = self.get_mapped_blocks(
np.asarray(local_block_ids0), block_size_ratio
).tolist()
if len(local_block_ids_mapped) > len(remote_block_ids0):
# NOTE:
# get_mapped_blocks will always expand block_ids for n times.
# ex:
# prefill block_ids with block_size as 4:
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Local decode block_ids with block_size as 16: [1, 2, 3]
# expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to
# expanded decode block_ids with get_mapped_blocks from [1, 2, 3] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# Then we clip local to align with prefill
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
local_block_ids = local_block_ids[: len(remote_block_ids)]
local_block_ids_mapped = local_block_ids_mapped[
: len(remote_block_ids0)
]
local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else []
remote_block_ids = [remote_block_ids0]
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
......@@ -2269,8 +2363,7 @@ class NixlConnectorWorker:
# then we will need to have the staging blocks on the remote side.
# NOTE(rob): according to nvidia the staging blocks are used to
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# saturate IB with heterogeneous TP sizes.
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
......@@ -2278,8 +2371,8 @@ class NixlConnectorWorker:
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
if len(local_block_ids) == 0:
# A full prefix cache hit is indicated with an empty list.
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
......@@ -2297,66 +2390,34 @@ class NixlConnectorWorker:
self.xfer_stats.record_failed_notification()
return
# Partial prefix cache hit: just read uncomputed blocks.
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
assert (
len(remote_block_ids)
== len(local_block_ids)
== len(self.kv_cache_config.kv_cache_groups)
)
remote_block_ids = list(remote_block_ids)
for i, remote_group in enumerate(remote_block_ids):
num_remote_blocks = len(remote_group)
num_local_blocks = len(local_block_ids[i])
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
if num_local_blocks < num_remote_blocks:
remote_block_ids[i] = remote_group[-num_local_blocks:]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids: np.ndarray
remote_block_descs_ids: np.ndarray
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id,
remote_block_ids,
)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id,
local_block_ids,
block_size_ratio=block_size_ratio,
)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list = []
remote_descs_list = []
for layer_idx, block_window in enumerate(self.block_window_per_layer):
# For each layer:
if block_window is None:
# If not chunked, we just use the
# full block lists (global attention)
layer_local_block_ids = local_block_ids
layer_remote_block_ids = remote_block_ids
else:
# If chunked, get the last block_window blocks
layer_local_block_ids = local_block_ids[-block_window:]
layer_remote_block_ids = remote_block_ids[-block_window:]
# Get descs ids for the layer.
layer_local_desc_ids = self._get_block_descs_ids(
self.engine_id,
layer_local_block_ids,
layer_idx,
block_size_ratio=block_size_ratio,
)
layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id,
layer_remote_block_ids,
layer_idx,
)
local_descs_list.append(layer_local_desc_ids)
remote_descs_list.append(layer_remote_desc_ids)
local_block_descs_ids = np.concatenate(local_descs_list)
remote_block_descs_ids = np.concatenate(remote_descs_list)
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id,
remote_block_ids,
)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id,
local_block_ids,
block_size_ratio=block_size_ratio,
)
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
......@@ -2387,14 +2448,18 @@ class NixlConnectorWorker:
dst_engine_id=dst_engine_id,
remote_rank=remote_rank,
)
if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids)
if (
meta := self._recving_metadata.get(request_id)
) and not self._is_hma_required:
self._invalid_block_ids.update(meta.local_block_ids[0])
self.xfer_stats.record_failed_transfer()
if handle is not None:
self.nixl_wrapper.release_xfer_handle(handle)
self._failed_recv_reqs.add(request_id)
def get_mapped_blocks(self, block_ids, block_size_ratio):
def get_mapped_blocks(
self, block_ids: np.ndarray, block_size_ratio: int
) -> np.ndarray:
"""
Calculates the new set of block IDs by mapping every element
in the (potentially sparse) input array.
......@@ -2416,41 +2481,32 @@ class NixlConnectorWorker:
def _get_block_descs_ids(
self,
engine_id: str,
block_ids: list[int],
layer_idx: int | None = None,
block_ids: BlockIds,
block_size_ratio: float | None = None,
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
When HMA is enabled number of descriptors across kv cache groups might differ.
A single flattened array is returned for all groups anyway.
"""
if layer_idx is None:
region_ids = np.arange(self.num_regions)
else:
assert layer_idx < self.num_layers
if self.num_layers < self.num_regions:
# If we have more regions than layers, we assume that
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
else:
# Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1)
region_ids = np.arange(self.num_regions)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# layers from different groups share the same kv tensor.
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
# same for [3], but group0-group1 blocks will always differ (different areas).
# Therefore we can just flatten the block_ids and compute the descs ids for all
# groups at once.
num_blocks = self.dst_num_blocks[engine_id]
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
# Compute the desc ids for each block.
region_ids = region_ids[:, None]
block_ids = np.array(block_ids)[None, :]
block_ids = np.concatenate(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
......@@ -2459,13 +2515,17 @@ class NixlConnectorWorker:
if self._physical_blocks_per_logical_kv_block == 1:
# Noop when physical and logical block sizes are the same
return block_ids
block_ids_np = np.array(block_ids)
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
return BlockTable.map_to_kernel_blocks(
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
return [
BlockTable.map_to_kernel_blocks(
np.array(group),
self._physical_blocks_per_logical_kv_block,
block_arange,
).tolist()
for group in block_ids
]
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
"""
......
......@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert len(self.blocks) == 1, "Only one group is supported"
return [block.block_id for block in self.blocks[0] if block.block_hash is None]
def get_unhashed_block_ids_all_groups(self) -> list[list[int]]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
# Skip padding blocks.
return [
[
block.block_id
for block in group
if block.block_hash is None and not block.is_null
]
for group in self.blocks
]
def new_empty(self) -> "KVCacheBlocks":
"""
Creates a new KVCacheBlocks instance with no blocks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment