"vscode:/vscode.git/clone" did not exist on "fc3eeea966d0b1fc0915709ef0de7a5d378619fa"
Unverified Commit 5b3ba94a authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Core][KVConnector] Support HMA+NixlConnector (#35758)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 90f3c01f
......@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
)
dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
......@@ -26,6 +27,14 @@ else
configs=("${tp_configs[@]}")
fi
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for i in "${!configs[@]}"; do
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
done
fi
run_tests() {
local label=$1
local extra_args=$2
......
......@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS="False"
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
fi
while [[ $# -gt 0 ]]; do
case $1 in
--kv_buffer_device)
......@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
if [[ -n "$ENABLE_HMA_VAR" ]]; then
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
......@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS=${VLLM_SERVE_EXTRA_ARGS:-}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
......@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
FULL_CMD="$BASE_CMD"
eval "$FULL_CMD &"
# Store host and port for proxy configuration
......@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
# DP-EP attention mode
if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
......@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
}
SIMPLE_PROMPT = (
......
......@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config
from .utils import (
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)
@pytest.fixture(scope="module", autouse=True)
......@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
for block_id, block in zip(
req_meta.local_block_ids,
req_meta.local_block_ids[0],
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
],
......@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
......@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
delay, kv_connector_metadata = (
scheduler.get_kv_connector().request_finished_all_groups(
request, ([0, 1, 2],)
)
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata.
......@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
self,
*args,
hand_shake_latency: float = 1.8,
kv_cache_layout="HND",
kv_cache_config=None,
**kwargs,
):
super().__init__(*args, **kwargs)
if kv_cache_config is None:
kv_cache_config = make_kv_cache_config(block_size=16)
super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
......@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id = "req_id"
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers -= 1
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
local_block_ids=([num_xfers + 1, num_xfers + 2, num_xfers + 3],),
kv_transfer_params={
"remote_block_ids": [
num_xfers + 4,
num_xfers + 5,
num_xfers + 6,
],
"remote_block_ids": (
[
num_xfers + 4,
num_xfers + 5,
num_xfers + 6,
],
),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id="id",
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id",
"remote_host": "localhost",
......@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0 = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p1 = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
......@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
......@@ -827,9 +859,9 @@ class TestNixlHandshake:
for i in range(total_reqs):
metadata.add_new_req_to_recv(
request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost",
......@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
connector.engine_id,
......@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend.return_value = backend_cls
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
): # noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
# Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
......@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
)
worker = NixlConnectorWorker(
vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with (
......@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler = create_scheduler(vllm_config)
# KVConnector Worker in P
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
("transfer_exception", {"fail_transfer_exception": True}, True),
],
)
@pytest.mark.parametrize("enable_hma", [False, True])
def test_transfer_failure_logging(
default_vllm_config,
dist_init,
failure_type,
wrapper_config,
needs_get_finished,
enable_hma,
):
"""Test that transfer failures are logged with structured context.
......@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.0
vllm_config,
connector.engine_id,
hand_shake_latency=0.0,
kv_cache_config=connector._kv_cache_config,
)
# Configure FailingNixlWrapper to fail in the specified way
......@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
remote_blocks = [20, 21, 22]
local_blocks: tuple[()] | tuple[list[int], ...]
if enable_hma:
# HMA enabled: multiple groups (FA + SW)
local_blocks = (
() if failure_type == "notification_failed" else ([10, 11, 12], [13, 14])
)
remote_blocks = [[20, 21, 22], [23, 24]]
else:
# HMA disabled: single group
local_blocks = () if failure_type == "notification_failed" else ([10, 11, 12],)
remote_blocks = [[20, 21, 22]]
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
......@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.1
)
......@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[1, 2, 3],
local_block_ids=([1, 2, 3],),
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
......@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id=request_id,
local_block_ids=[7, 8, 9],
local_block_ids=([7, 8, 9],),
kv_transfer_params={
"remote_block_ids": [10, 11, 12],
"remote_block_ids": ([10, 11, 12],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
......@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat
},
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
......@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model="facebook/opt-125m",
block_size=16,
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from unittest.mock import patch
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager,
SlidingWindowManager,
)
from .utils import (
create_vllm_config,
make_kv_cache_config,
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "hma_enabled,expected_sw_sizes",
    [
        # HMA on: a full-attention group (reported as 0) followed by a
        # sliding-window group. 2048 tokens / 16 tokens per block = 128
        # blocks, and the expected value carries one extra block on top.
        (True, [0, 128 + 1]),
        # HMA off: only the full-attention group (reported as 0).
        (False, [0]),
    ],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
    """Check ``NixlConnectorScheduler.blocks_per_sw`` with HMA on and off.

    The scheduler is built from a KV cache config produced by
    ``make_kv_cache_config`` with a 2048-token sliding window and 16-token
    blocks; the platform seen by the connector module is mocked as "cpu".
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    mock_platform.device_type = "cpu"

    tokens_per_block = 16
    # The config arguments are evaluated in the same order as before:
    # vllm config first, then the KV cache config (SW of 2048 tokens).
    scheduler = NixlConnectorScheduler(
        vllm_config=create_vllm_config(block_size=tokens_per_block),
        engine_id="test-engine",
        kv_cache_config=make_kv_cache_config(
            block_size=tokens_per_block, hma_enabled=hma_enabled, sw_size=2048
        ),
    )

    # blocks_per_sw is expressed in number of blocks, one entry per group.
    assert scheduler.blocks_per_sw == expected_sw_sizes, (
        f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}"
    )
@pytest.mark.cpu_test
def test_logical_to_kernel_block_ids_with_hma():
    """Check that ``_logical_to_kernel_block_ids`` expands every logical block.

    With HMA the logical block size can be larger than the kernel block size,
    so one logical block id corresponds to several consecutive kernel block
    ids. Here the ratio is 2, i.e. logical ``[0]`` becomes kernel ``[0, 1]``.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    # Allocate the worker without running __init__; only the attribute the
    # method under test reads is populated by hand.
    worker = object.__new__(NixlConnectorWorker)
    # Logical block size 32 vs kernel block size 16 => 2 kernel blocks each.
    worker._physical_blocks_per_logical_kv_block = 2

    # Two groups (FA + SW) of logical ids and their expected expansion.
    fa_logical, sw_logical = [0, 1, 2], [3, 4]
    expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]

    kernel_block_ids = worker._logical_to_kernel_block_ids([fa_logical, sw_logical])

    assert kernel_block_ids == expected_kernel_block_ids, (
        f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}"
    )
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
    """Check that prefill reports fewer "remote blocks" for the SWA groups.

    A prompt longer than the sliding window is run through an ``LLM`` set up
    as a NixlConnector instance with the hybrid KV cache manager enabled. The
    returned ``remote_block_ids`` must hold ``sw_size // block_size + 1``
    blocks for every group but the last, and more than that for the last one.
    """
    block_size = 16
    llm_kwargs = {
        "model": model_name,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.5,
        "kv_transfer_config": KVTransferConfig(
            kv_connector="NixlConnector",
            kv_role="kv_both",
        ),
        "max_model_len": 2048,
        # NOTE: Make sure HMA is enabled
        "disable_hybrid_kv_cache_manager": False,
        "max_num_batched_tokens": 1024,
        "enable_prefix_caching": False,
        "block_size": block_size,
    }
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    llm = LLM(**llm_kwargs)
    try:
        # Simulate sidecar request: ask for remote decode, nothing prefilled.
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=1,
            extra_args={
                "kv_transfer_params": {
                    "do_remote_decode": True,
                    "do_remote_prefill": False,
                    "remote_engine_id": None,
                    "remote_block_ids": None,
                    "remote_host": None,
                    "remote_port": None,
                }
            },
        )

        scheduler = llm.llm_engine.engine_core.engine_core.scheduler
        kv_managers = scheduler.kv_cache_manager.coordinator.single_type_managers

        # HMA enabled with FA + SWA groups
        assert len(kv_managers) > 2
        allowed_manager_types = (SlidingWindowManager, FullAttentionManager)
        for kv_manager in kv_managers:
            assert isinstance(kv_manager, allowed_manager_types)

        # No request has been tracked before generation starts.
        assert len(kv_managers[0].req_to_blocks) == 0

        # Process some request with length exceeding the sliding window
        outputs = llm.generate(["hi" * 1401], sampling_params)
        remote_block_ids = outputs[0].kv_transfer_params["remote_block_ids"]

        # +1 to account for overlapping window across blocks.
        expected_num_remote_blocks = sw_size // block_size + 1

        # First group is clipped to the window; the last group holds more.
        assert (
            len(remote_block_ids[0])
            == expected_num_remote_blocks
            < len(remote_block_ids[-1])
        )
        # Every group except the last one must be clipped the same way.
        for group_block_ids in remote_block_ids[:-1]:
            assert len(group_block_ids) == expected_num_remote_blocks
    finally:
        # Always stop the engine core, even when an assertion above fails.
        llm.llm_engine.engine_core.shutdown()
@pytest.mark.cpu_test
def test_nixl_metadata_hma_block_ids_structure():
    """Check per-group block id storage in ``NixlConnectorMetadata``.

    A request is registered with local and remote block ids for two KV cache
    groups (FA + SW); the stored request metadata must keep both groups, in
    order, with their ids unchanged.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    request_id = "test-req-hma"
    fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7]  # 8 blocks for FA
    sw_blocks = [8, 9, 10, 11]  # 4 blocks for SW (clipped)

    metadata = NixlConnectorMetadata()
    # Add request with block IDs for 2 groups (FA + SW)
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=(fa_blocks, sw_blocks),
        kv_transfer_params={
            "remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21]),
            "remote_engine_id": "remote-engine",
            "remote_request_id": "prefill-test-req-hma",
            "remote_host": "localhost",
            "remote_port": 1234,
            "tp_size": 1,
        },
    )

    assert request_id in metadata.reqs_to_recv
    req_meta = metadata.reqs_to_recv[request_id]

    # Verify local block IDs structure
    local_ids = req_meta.local_block_ids
    assert len(local_ids) == 2
    assert list(local_ids[0]) == fa_blocks
    assert list(local_ids[1]) == sw_blocks

    # Verify remote block IDs structure (compared against fresh literals)
    assert req_meta.remote is not None
    remote_ids = req_meta.remote.block_ids
    assert len(remote_ids) == 2
    assert list(remote_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
    assert list(remote_ids[1]) == [18, 19, 20, 21]
......@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks = sum(len(g) for g in kv_transfer_params["remote_block_ids"])
assert num_remote_blocks == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
......
......@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
......@@ -142,24 +143,26 @@ def create_vllm_config(
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
kv_cache_config: KVCacheConfig | None = None,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
if kv_cache_config is None:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
......@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector(
"MockKVConnector", __name__, MockKVConnector.__name__
)
def make_kv_cache_config(
    block_size: int,
    hma_enabled: bool = False,
    sw_size: int = 128,
    num_blocks: int = 100,
) -> KVCacheConfig:
    """Build a small ``KVCacheConfig`` for connector tests.

    Args:
        block_size: Block size given to every attention spec.
        hma_enabled: When true, a sliding-window group is added after the
            full-attention group; otherwise only the full-attention group
            is present.
        sw_size: ``sliding_window`` value of the sliding-window spec (only
            used when ``hma_enabled`` is true).
        num_blocks: ``num_blocks`` of the resulting config.

    Returns:
        A config with no KV cache tensors and one or two cache groups.
    """
    # Settings shared by every spec created below.
    shared_spec_kwargs = {
        "block_size": block_size,
        "num_kv_heads": 4,
        "head_size": 16,
        "dtype": torch.float16,
    }

    full_attention_group = KVCacheGroupSpec(
        ["layer0", "layer2"],
        FullAttentionSpec(**shared_spec_kwargs),
    )
    if not hma_enabled:
        kv_cache_groups = [full_attention_group]
    else:
        sliding_window_group = KVCacheGroupSpec(
            ["layer1", "layer3"],
            SlidingWindowSpec(sliding_window=sw_size, **shared_spec_kwargs),
        )
        # Full attention stays first, sliding window second.
        kv_cache_groups = [full_attention_group, sliding_window_group]

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=kv_cache_groups,
    )
......@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
EngineId = str
# Block ids as returned by the hybrid KV cache manager. The list[list[int]]
# variant allows mutability and is for connector-internal use only.
BlockIds = tuple[list[int], ...] | list[list[int]]
def get_kv_connector_cache_layout():
......
......@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert len(self.blocks) == 1, "Only one group is supported"
return [block.block_id for block in self.blocks[0] if block.block_hash is None]
def get_unhashed_block_ids_all_groups(self) -> list[list[int]]:
    """Collect, for each group, the ids of blocks that have no hash.

    Returns:
        One list per entry of ``self.blocks`` (same order), holding the
        ``block_id`` of every block whose ``block_hash`` is ``None`` and
        that is not a null (padding) block.
    """
    ids_per_group: list[list[int]] = []
    for group in self.blocks:
        unhashed_ids: list[int] = []
        for block in group:
            # Already-hashed blocks and padding blocks are skipped.
            if block.block_hash is not None or block.is_null:
                continue
            unhashed_ids.append(block.block_id)
        ids_per_group.append(unhashed_ids)
    return ids_per_group
def new_empty(self) -> "KVCacheBlocks":
"""
Creates a new KVCacheBlocks instance with no blocks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment