Unverified commit 5b3ba94a, authored by Nicolò Lucchesi and committed by GitHub
Browse files

[Core][KVConnector] Support HMA+NixlConnector (#35758)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 90f3c01f
...@@ -12,6 +12,7 @@ tp_configs=( ...@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
) )
dp_ep_configs=( dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
...@@ -26,6 +27,14 @@ else ...@@ -26,6 +27,14 @@ else
configs=("${tp_configs[@]}") configs=("${tp_configs[@]}")
fi fi
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
# Prepend ENABLE_HMA_FLAG=1 to each config in the selected array
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for i in "${!configs[@]}"; do
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
done
fi
run_tests() { run_tests() {
local label=$1 local label=$1
local extra_args=$2 local extra_args=$2
......
...@@ -5,6 +5,12 @@ set -xe ...@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE="cuda" # Default to cuda KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default) ATTENTION_BACKEND="" # Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS="False" CROSS_LAYERS_BLOCKS="False"
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
fi
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--kv_buffer_device) --kv_buffer_device)
...@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" ...@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND" echo "Using attention backend: $ATTENTION_BACKEND"
fi fi
if [[ -n "$ENABLE_HMA_VAR" ]]; then
echo "HMA (Hybrid KV Cache Manager) enabled"
fi
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
...@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} ...@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128} PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128} DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS=${VLLM_SERVE_EXTRA_ARGS:-}
# Find the git repository root directory # Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel) GIT_ROOT=$(git rev-parse --show-toplevel)
...@@ -151,14 +165,24 @@ run_tests_for_model() { ...@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \ --tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified # Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
FULL_CMD="$BASE_CMD" FULL_CMD="$BASE_CMD"
eval "$FULL_CMD &" eval "$FULL_CMD &"
# Store host and port for proxy configuration # Store host and port for proxy configuration
...@@ -193,12 +217,23 @@ run_tests_for_model() { ...@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \ --block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
BASE_CMD="${BASE_CMD} $arg"
done
fi
# Add attention backend config if specified # Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi fi
# Add HMA flag if specified
if [[ -n "$ENABLE_HMA_VAR" ]]; then
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
fi
# DP-EP attention mode # DP-EP attention mode
if [[ -z "$DP_EP" ]]; then if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
...@@ -17,6 +17,7 @@ EXPECTED_VALUES = { ...@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small": 0.59, "deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19, "deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65, "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
} }
SIMPLE_PROMPT = ( SIMPLE_PROMPT = (
......
...@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus ...@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config from .utils import (
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
...@@ -263,7 +268,7 @@ def test_basic_interface(): ...@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta = kv_connector_metadata.reqs_to_recv[request_id] req_meta = kv_connector_metadata.reqs_to_recv[request_id]
for block_id, block in zip( for block_id, block in zip(
req_meta.local_block_ids, req_meta.local_block_ids[0],
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id request_id
], ],
...@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init): ...@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake # Prefill connector will register KV cache to populate proper handshake
# metadata. # metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
) )
...@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init): ...@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode=True, do_remote_decode=True,
) )
request.status = RequestStatus.FINISHED_LENGTH_CAPPED request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished( delay, kv_connector_metadata = (
request, [0, 1, 2] scheduler.get_kv_connector().request_finished_all_groups(
request, ([0, 1, 2],)
)
) )
assert delay assert delay
# Decode connector will be able to create handshake with the prefill connector. # Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_connector.register_kv_caches(kv_caches) decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata. # Here we are testing the retrieval of NIXLAgentMetadata.
...@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine" REMOTE_ENGINE_ID = "remote_engine"
def __init__( def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs self,
*args,
hand_shake_latency: float = 1.8,
kv_cache_layout="HND",
kv_cache_config=None,
**kwargs,
): ):
super().__init__(*args, **kwargs) if kv_cache_config is None:
kv_cache_config = make_kv_cache_config(block_size=16)
super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it. # Mock register_kv_caches attribute needed for tests that do not call it.
...@@ -507,7 +525,9 @@ class TestNixlHandshake: ...@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id = "req_id" request_id = "req_id"
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -528,13 +548,15 @@ class TestNixlHandshake: ...@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers -= 1 num_xfers -= 1
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], local_block_ids=([num_xfers + 1, num_xfers + 2, num_xfers + 3],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [ "remote_block_ids": (
num_xfers + 4, [
num_xfers + 5, num_xfers + 4,
num_xfers + 6, num_xfers + 5,
], num_xfers + 6,
],
),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}", "remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -594,16 +616,18 @@ class TestNixlHandshake: ...@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id="id", request_id="id",
local_block_ids=[1, 2, 3], local_block_ids=([1, 2, 3],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [4, 5, 6], "remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id", "remote_request_id": "prefill-id",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -652,7 +676,9 @@ class TestNixlHandshake: ...@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size = 1 local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -717,8 +743,12 @@ class TestNixlHandshake: ...@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size = 2 p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks. # Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER) conn_p0 = NixlConnector(
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER) vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p1 = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
conn_p0.connector_worker = FakeNixlConnectorWorker( conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0 vllm_config, conn_p0.engine_id, hand_shake_latency=0
) )
...@@ -815,7 +845,9 @@ class TestNixlHandshake: ...@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
...@@ -827,9 +859,9 @@ class TestNixlHandshake: ...@@ -827,9 +859,9 @@ class TestNixlHandshake:
for i in range(total_reqs): for i in range(total_reqs):
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id=f"id_{i}", request_id=f"id_{i}",
local_block_ids=[1, 2, 3], local_block_ids=([1, 2, 3],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [4, 5, 6], "remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}", "remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -884,7 +916,9 @@ class TestNixlHandshake: ...@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -934,7 +968,9 @@ class TestNixlHandshake: ...@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, vllm_config,
connector.engine_id, connector.engine_id,
...@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init): ...@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
# Test worker role in decode server. # Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init): ...@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[1, 2, 3], local_block_ids=([1, 2, 3],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [4, 5, 6], "remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}", "remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -1448,7 +1486,9 @@ def test_register_kv_caches( ...@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend.return_value = backend_cls mock_get_attn_backend.return_value = backend_cls
# Create connector # Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types( ...@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
), ),
): # noqa: E501 ): # noqa: E501
# Create connector and replace its worker with a fake one for isolation # Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
# Verify get_reg_descs was called with the correct memory_type # Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device assert connector.connector_worker.kv_buffer_device == kv_buffer_device
...@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): ...@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler( scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
)
worker = NixlConnectorWorker(
vllm_config,
vllm_config.kv_transfer_config.engine_id,
make_kv_cache_config(block_size=16),
) )
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper nixl_wrapper = worker.nixl_wrapper
with ( with (
...@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ ...@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
# KVConnector Worker in P # KVConnector Worker in P
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper): ...@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
("transfer_exception", {"fail_transfer_exception": True}, True), ("transfer_exception", {"fail_transfer_exception": True}, True),
], ],
) )
@pytest.mark.parametrize("enable_hma", [False, True])
def test_transfer_failure_logging( def test_transfer_failure_logging(
default_vllm_config, default_vllm_config,
dist_init, dist_init,
failure_type, failure_type,
wrapper_config, wrapper_config,
needs_get_finished, needs_get_finished,
enable_hma,
): ):
"""Test that transfer failures are logged with structured context. """Test that transfer failures are logged with structured context.
...@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging( ...@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.0 vllm_config,
connector.engine_id,
hand_shake_latency=0.0,
kv_cache_config=connector._kv_cache_config,
) )
# Configure FailingNixlWrapper to fail in the specified way # Configure FailingNixlWrapper to fail in the specified way
...@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging( ...@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks # For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif) # (full cache hit path to trigger send_notif)
local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12] local_blocks: tuple[()] | tuple[list[int], ...]
remote_blocks = [20, 21, 22] if enable_hma:
# HMA enabled: multiple groups (FA + SW)
local_blocks = (
() if failure_type == "notification_failed" else ([10, 11, 12], [13, 14])
)
remote_blocks = [[20, 21, 22], [23, 24]]
else:
# HMA disabled: single group
local_blocks = () if failure_type == "notification_failed" else ([10, 11, 12],)
remote_blocks = [[20, 21, 22]]
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
...@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished.""" """Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0.1 vllm_config, connector.engine_id, hand_shake_latency=0.1
) )
...@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[1, 2, 3], local_block_ids=([1, 2, 3],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [4, 5, 6], "remote_block_ids": ([4, 5, 6],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}", "remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished.""" and return via get_finished."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
...@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[7, 8, 9], local_block_ids=([7, 8, 9],),
kv_transfer_params={ kv_transfer_params={
"remote_block_ids": [10, 11, 12], "remote_block_ids": ([10, 11, 12],),
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}", "remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost", "remote_host": "localhost",
...@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation( ...@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat": enforce_handshake_compat "enforce_handshake_compat": enforce_handshake_compat
}, },
) )
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
...@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model="facebook/opt-125m", model="facebook/opt-125m",
block_size=16, block_size=16,
) )
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
)
decode_worker = decode_connector.connector_worker decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config) backend = get_current_attn_backend(local_vllm_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from unittest.mock import patch
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager,
SlidingWindowManager,
)
from .utils import (
create_vllm_config,
make_kv_cache_config,
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "hma_enabled,expected_sw_sizes",
    [
        # HMA on: one entry per KV cache group. The full-attention group is
        # reported as 0; the sliding-window group covers 2048 tokens, i.e.
        # 2048 / 16 = 128 blocks, and the expected value is one block larger
        # (presumably for a window that straddles a block boundary).
        (True, [0, 128 + 1]),
        # HMA off: a single full-attention group, reported as 0.
        (False, [0]),
    ],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
    """Verify the scheduler's per-group sliding-window sizes with/without HMA.

    Builds a ``NixlConnectorScheduler`` on top of a KV cache config created
    with ``hma_enabled`` and a 2048-token sliding window, then compares
    ``scheduler.blocks_per_sw`` (expressed in blocks) with the expectation.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    # The connector module's platform is mocked so the test runs on CPU.
    mock_platform.device_type = "cpu"

    tokens_per_block = 16
    scheduler = NixlConnectorScheduler(
        vllm_config=create_vllm_config(block_size=tokens_per_block),
        engine_id="test-engine",
        kv_cache_config=make_kv_cache_config(
            block_size=tokens_per_block,
            hma_enabled=hma_enabled,
            sw_size=2048,
        ),
    )

    # Sizes are counted in blocks, not tokens.
    assert scheduler.blocks_per_sw == expected_sw_sizes, (
        f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}"
    )
@pytest.mark.cpu_test
def test_logical_to_kernel_block_ids_with_hma():
    """Check that logical block IDs are expanded into kernel block IDs.

    With HMA the logical block size can be a multiple of the kernel block
    size, in which case every logical block corresponds to several
    consecutive kernel blocks.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    # Allocate the worker without running __init__; only the attribute set
    # below is provided for the method under test.
    worker = object.__new__(NixlConnectorWorker)

    # Emulates logical block size 32 over kernel block size 16: each logical
    # block expands to two kernel blocks, e.g. [0] -> [0, 1].
    worker._physical_blocks_per_logical_kv_block = 2

    # Two groups (FA + SW) go in; each ID i becomes the pair (2i, 2i + 1).
    expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]
    kernel_block_ids = worker._logical_to_kernel_block_ids([[0, 1, 2], [3, 4]])

    assert kernel_block_ids == expected_kernel_block_ids, (
        f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}"
    )
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
    """Check that sliding-window groups expose fewer remote blocks under HMA.

    A prefill instance serving a prompt longer than the sliding window should
    return a clipped list of "remote blocks" for the SWA groups, while the
    last group keeps a longer list.
    """
    block_size = 16
    llm_kwargs = {
        "model": model_name,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.5,
        "kv_transfer_config": KVTransferConfig(
            kv_connector="NixlConnector",
            kv_role="kv_both",
        ),
        "max_model_len": 2048,
        # NOTE: HMA must stay enabled for this test to be meaningful.
        "disable_hybrid_kv_cache_manager": False,
        "max_num_batched_tokens": 1024,
        "enable_prefix_caching": False,
        "block_size": block_size,
    }
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    def check_swa_groups_are_clipped(llm: LLM):
        """Run one remote-decode request and inspect the returned block IDs."""
        # Mimic the request a sidecar would send to a prefill instance.
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=1,
            extra_args={
                "kv_transfer_params": {
                    "do_remote_decode": True,
                    "do_remote_prefill": False,
                    "remote_engine_id": None,
                    "remote_block_ids": None,
                    "remote_host": None,
                    "remote_port": None,
                }
            },
        )

        engine_scheduler = llm.llm_engine.engine_core.engine_core.scheduler
        managers = engine_scheduler.kv_cache_manager.coordinator.single_type_managers

        # HMA active: several groups, each either sliding-window or
        # full-attention.
        assert len(managers) > 2
        assert all(
            isinstance(manager, (SlidingWindowManager, FullAttentionManager))
            for manager in managers
        )
        # Nothing has been scheduled yet.
        assert len(managers[0].req_to_blocks) == 0

        # Prompt chosen to be longer than the sliding window.
        outputs = llm.generate(["hi" * 1401], sampling_params)
        remote_block_ids = outputs[0].kv_transfer_params["remote_block_ids"]

        # One extra block because the window may overlap a block boundary.
        expected_num_remote_blocks = sw_size // block_size + 1

        # The first group is clipped to the window; the last group is longer.
        assert len(remote_block_ids[0]) == expected_num_remote_blocks
        assert expected_num_remote_blocks < len(remote_block_ids[-1])
        # Every group except the last one is clipped the same way.
        assert all(
            len(group_block_ids) == expected_num_remote_blocks
            for group_block_ids in remote_block_ids[:-1]
        )

    llm = LLM(**llm_kwargs)
    try:
        check_swa_groups_are_clipped(llm)
    finally:
        # Always tear the engine down, even when an assertion fails.
        llm.llm_engine.engine_core.shutdown()
@pytest.mark.cpu_test
def test_nixl_metadata_hma_block_ids_structure():
    """Check that NixlConnectorMetadata keeps per-group block ids intact.

    A request is registered with two groups of local and remote block ids
    (FA + SW); both must be readable back, group by group, from the stored
    request metadata.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    request_id = "test-req-hma"
    local_groups = (
        [0, 1, 2, 3, 4, 5, 6, 7],  # 8 blocks for FA
        [8, 9, 10, 11],  # 4 blocks for SW (clipped)
    )
    # Kept separate from the lists handed to the metadata object so the
    # comparison below never aliases the stored data.
    expected_remote_groups = (
        [10, 11, 12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21],
    )

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=local_groups,
        kv_transfer_params={
            "remote_block_ids": (list(range(10, 18)), list(range(18, 22))),
            "remote_engine_id": "remote-engine",
            "remote_request_id": "prefill-test-req-hma",
            "remote_host": "localhost",
            "remote_port": 1234,
            "tp_size": 1,
        },
    )

    assert request_id in metadata.reqs_to_recv
    req_meta = metadata.reqs_to_recv[request_id]

    # Verify local block IDs structure
    assert len(req_meta.local_block_ids) == 2
    for group_idx, wanted in enumerate(local_groups):
        assert list(req_meta.local_block_ids[group_idx]) == wanted

    # Verify remote block IDs structure
    assert req_meta.remote is not None
    assert len(req_meta.remote.block_ids) == 2
    for group_idx, wanted in enumerate(expected_remote_groups):
        assert list(req_meta.remote.block_ids[group_idx]) == wanted
...@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle(): ...@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks, # Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit. # even if there is a cache hit.
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1) # remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks = sum(len(g) for g in kv_transfer_params["remote_block_ids"])
assert num_remote_blocks == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# STEP (2): Ensure it is freed. # STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
......
...@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheConfig,
KVCacheGroupSpec, KVCacheGroupSpec,
SlidingWindowSpec,
) )
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -142,24 +143,26 @@ def create_vllm_config( ...@@ -142,24 +143,26 @@ def create_vllm_config(
def create_scheduler( def create_scheduler(
vllm_config: VllmConfig, vllm_config: VllmConfig,
num_blocks: int = 10000, num_blocks: int = 10000,
kv_cache_config: KVCacheConfig | None = None,
) -> Scheduler: ) -> Scheduler:
"""Initialize Scheduler For Testing.""" """Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig( if kv_cache_config is None:
num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_config = KVCacheConfig(
kv_cache_tensors=[], num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_groups=[ kv_cache_tensors=[],
KVCacheGroupSpec( kv_cache_groups=[
["layer"], KVCacheGroupSpec(
FullAttentionSpec( ["layer"],
block_size=block_size, FullAttentionSpec(
num_kv_heads=1, block_size=block_size,
head_size=1, num_kv_heads=1,
dtype=torch.float32, head_size=1,
), dtype=torch.float32,
) ),
], )
) ],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler( return Scheduler(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector( ...@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"MockKVConnector", __name__, MockKVConnector.__name__ "MockKVConnector", __name__, MockKVConnector.__name__
) )
def make_kv_cache_config(
    block_size: int,
    hma_enabled: bool = False,
    sw_size: int = 128,
    num_blocks: int = 100,
) -> KVCacheConfig:
    """Build a small KVCacheConfig with one or two KV cache groups.

    Args:
        block_size: Block size used by every group spec.
        hma_enabled: When true, a sliding-window group over
            ``layer1``/``layer3`` is added after the full-attention group.
        sw_size: Sliding window passed to the sliding-window spec.
        num_blocks: Number of blocks recorded in the returned config.

    Returns:
        A config with no KV cache tensors and the groups described above.
    """
    # Spec fields shared by the full-attention and sliding-window groups.
    shared_spec_kwargs = {
        "block_size": block_size,
        "num_kv_heads": 4,
        "head_size": 16,
        "dtype": torch.float16,
    }
    full_attention_group = KVCacheGroupSpec(
        ["layer0", "layer2"],
        FullAttentionSpec(**shared_spec_kwargs),
    )
    if not hma_enabled:
        groups = [full_attention_group]
    else:
        sliding_window_group = KVCacheGroupSpec(
            ["layer1", "layer3"],
            SlidingWindowSpec(sliding_window=sw_size, **shared_spec_kwargs),
        )
        groups = [full_attention_group, sliding_window_group]
    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
...@@ -24,6 +24,9 @@ if TYPE_CHECKING: ...@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
EngineId = str EngineId = str
# Block ids as returned by the hybrid KV cache manager. The list[list[int]] variant
# allows mutability and is for connector internal use only.
BlockIds = tuple[list[int], ...] | list[list[int]]
def get_kv_connector_cache_layout(): def get_kv_connector_cache_layout():
......
...@@ -84,6 +84,18 @@ class KVCacheBlocks: ...@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert len(self.blocks) == 1, "Only one group is supported" assert len(self.blocks) == 1, "Only one group is supported"
return [block.block_id for block in self.blocks[0] if block.block_hash is None] return [block.block_id for block in self.blocks[0] if block.block_hash is None]
def get_unhashed_block_ids_all_groups(self) -> list[list[int]]:
    """Return, per group in ``self.blocks``, the ids of its unhashed blocks.

    A block is kept when its ``block_hash`` is ``None`` and it is not a
    null (padding) block. The result has one inner list per group, in the
    same order as ``self.blocks``.
    """
    per_group_ids: list[list[int]] = []
    for group_blocks in self.blocks:
        unhashed_ids: list[int] = []
        for blk in group_blocks:
            # Skip blocks that already carry a hash, and padding blocks.
            if blk.block_hash is not None or blk.is_null:
                continue
            unhashed_ids.append(blk.block_id)
        per_group_ids.append(unhashed_ids)
    return per_group_ids
def new_empty(self) -> "KVCacheBlocks": def new_empty(self) -> "KVCacheBlocks":
""" """
Creates a new KVCacheBlocks instance with no blocks. Creates a new KVCacheBlocks instance with no blocks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment