Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b3ba94a
Unverified
Commit
5b3ba94a
authored
Mar 06, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 06, 2026
Browse files
[Core][KVConnector] Support HMA+NixlConnector (#35758)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
90f3c01f
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
669 additions
and
230 deletions
+669
-230
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+9
-0
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+36
-1
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+122
-46
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+203
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+3
-1
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+53
-15
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+3
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+227
-167
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
5b3ba94a
...
...
@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
# SW model
)
dp_ep_configs
=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
...
...
@@ -26,6 +27,14 @@ else
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo
"ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for
i
in
"
${
!configs[@]
}
"
;
do
configs[
$i
]=
"ENABLE_HMA_FLAG=1
${
configs
[
$i
]
}
"
done
fi
run_tests
()
{
local
label
=
$1
local
extra_args
=
$2
...
...
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
5b3ba94a
...
...
@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS
=
"False"
ENABLE_HMA_VAR
=
""
# Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
ENABLE_HMA_VAR
=
"--no-disable-hybrid-kv-cache-manager"
fi
while
[[
$#
-gt
0
]]
;
do
case
$1
in
--kv_buffer_device
)
...
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
echo
"Using attention backend:
$ATTENTION_BACKEND
"
fi
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
echo
"HMA (Hybrid KV Cache Manager) enabled"
fi
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
echo
"vLLM serve extra args:
$VLLM_SERVE_EXTRA_ARGS
"
fi
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
...
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS
=
${
VLLM_SERVE_EXTRA_ARGS
:-}
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
FULL_CMD
=
"
$BASE_CMD
"
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
FULL_CMD
=
"
$BASE_CMD
"
eval
"
$FULL_CMD
&"
# Store host and port for proxy configuration
...
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size
${
DECODE_BLOCK_SIZE
}
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
# DP-EP attention mode
if
[[
-z
"
$DP_EP
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
5b3ba94a
...
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
}
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
5b3ba94a
...
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
(
create_request
,
create_scheduler
,
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
...
...
@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
req_meta
.
local_block_ids
[
0
]
,
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
request_id
],
...
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
...
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode
=
True
,
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
request
,
[
0
,
1
,
2
]
delay
,
kv_connector_metadata
=
(
scheduler
.
get_kv_connector
().
request_finished_all_groups
(
request
,
([
0
,
1
,
2
],)
)
)
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_connector
.
register_kv_caches
(
kv_caches
)
# Here we are testing the retrieval of NIXLAgentMetadata.
...
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID
=
"remote_engine"
def
__init__
(
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
**
kwargs
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
kv_cache_config
=
None
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
kv_cache_config
is
None
:
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
)
super
().
__init__
(
*
args
,
kv_cache_config
=
kv_cache_config
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
...
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id
=
"req_id"
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers
-=
1
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
local_block_ids
=
(
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
"remote_block_ids"
:
(
[
num_xfers
+
4
,
num_xfers
+
5
,
num_xfers
+
6
,
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
"id"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
"prefill-id"
,
"remote_host"
:
"localhost"
,
...
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
...
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
for
i
in
range
(
total_reqs
):
metadata
.
add_new_req_to_recv
(
request_id
=
f
"id_
{
i
}
"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value
=
2
,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value
=
2
,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
...
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend
.
return_value
=
backend_cls
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
):
# noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
# Verify get_reg_descs was called with the correct memory_type
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
...
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
scheduler
=
NixlConnectorScheduler
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
)
nixl_wrapper
=
worker
.
nixl_wrapper
with
(
...
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler
=
create_scheduler
(
vllm_config
)
# KVConnector Worker in P
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
],
)
@
pytest
.
mark
.
parametrize
(
"enable_hma"
,
[
False
,
True
])
def
test_transfer_failure_logging
(
default_vllm_config
,
dist_init
,
failure_type
,
wrapper_config
,
needs_get_finished
,
enable_hma
,
):
"""Test that transfer failures are logged with structured context.
...
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
enable_hma
),
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
,
kv_cache_config
=
connector
.
_kv_cache_config
,
)
# Configure FailingNixlWrapper to fail in the specified way
...
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
local_blocks
=
[]
if
failure_type
==
"notification_failed"
else
[
10
,
11
,
12
]
remote_blocks
=
[
20
,
21
,
22
]
local_blocks
:
tuple
[()]
|
tuple
[
list
[
int
],
...]
if
enable_hma
:
# HMA enabled: multiple groups (FA + SW)
local_blocks
=
(
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],
[
13
,
14
])
)
remote_blocks
=
[[
20
,
21
,
22
],
[
23
,
24
]]
else
:
# HMA disabled: single group
local_blocks
=
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],)
remote_blocks
=
[[
20
,
21
,
22
]]
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
...
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
)
...
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
7
,
8
,
9
],
local_block_ids
=
(
[
7
,
8
,
9
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_block_ids"
:
(
[
10
,
11
,
12
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat"
:
enforce_handshake_compat
},
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
...
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model
=
"facebook/opt-125m"
,
block_size
=
16
,
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
backend
=
get_current_attn_backend
(
local_vllm_config
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
0 → 100644
View file @
5b3ba94a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from
unittest.mock
import
patch
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SlidingWindowManager
,
)
from
.utils
import
(
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"hma_enabled,expected_sw_sizes"
,
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(
True
,
[
0
,
128
+
1
]),
# HMA disabled: only FullAttentionSpec (0)
(
False
,
[
0
]),
],
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_sw_sizes
(
mock_platform
,
hma_enabled
,
expected_sw_sizes
):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
)
mock_platform
.
device_type
=
"cpu"
block_size
=
16
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
# SW 2048 tokens=>128 blocks
kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
hma_enabled
=
hma_enabled
,
sw_size
=
2048
)
scheduler
=
NixlConnectorScheduler
(
vllm_config
=
vllm_config
,
engine_id
=
"test-engine"
,
kv_cache_config
=
kv_cache_config
,
)
# in number of blocks
assert
scheduler
.
blocks_per_sw
==
expected_sw_sizes
,
(
f
"Expected sw_sizes=
{
expected_sw_sizes
}
, got
{
scheduler
.
blocks_per_sw
}
"
)
@
pytest
.
mark
.
cpu_test
def
test_logical_to_kernel_block_ids_with_hma
():
"""Test _logical_to_kernel_block_ids expands blocks when HMA is enabled.
When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
# Create a mock worker with just the required attributes
# (use __new__ to skip __init__)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
kernel_block_ids
=
worker
.
_logical_to_kernel_block_ids
(
logical_block_ids
)
expected_kernel_block_ids
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
6
,
7
,
8
,
9
]]
assert
kernel_block_ids
==
expected_kernel_block_ids
,
(
f
"Expected
{
expected_kernel_block_ids
}
, got
{
kernel_block_ids
}
"
)
@
pytest
.
mark
.
parametrize
(
"model_name, sw_size"
,
[(
"google/gemma-3-1b-it"
,
512
)])
def
test_fewer_blocks_with_hma
(
monkeypatch
,
model_name
,
sw_size
):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
when sequence exceeds the sliding window.
"""
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
)
block_size
=
16
llm_kwargs
=
{
"model"
:
model_name
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.5
,
"kv_transfer_config"
:
kv_transfer_config
,
"max_model_len"
:
2048
,
# NOTE: Make sure HMA is enabled
"disable_hybrid_kv_cache_manager"
:
False
,
"max_num_batched_tokens"
:
1024
,
"enable_prefix_caching"
:
False
,
"block_size"
:
block_size
,
}
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
def
run_hma_test
(
llm
:
LLM
):
remote_prefill_opts
=
{
"do_remote_decode"
:
True
,
"do_remote_prefill"
:
False
,
"remote_engine_id"
:
None
,
"remote_block_ids"
:
None
,
"remote_host"
:
None
,
"remote_port"
:
None
,
}
# Simulate sidecar request
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
1
,
extra_args
=
{
"kv_transfer_params"
:
remote_prefill_opts
},
)
scheduler
=
llm
.
llm_engine
.
engine_core
.
engine_core
.
scheduler
kv_managers
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
# HMA enabled with FA + SWA groups
assert
len
(
kv_managers
)
>
2
for
kv_manager
in
kv_managers
:
assert
isinstance
(
kv_manager
,
(
SlidingWindowManager
,
FullAttentionManager
))
req_to_blocks
=
kv_managers
[
0
].
req_to_blocks
assert
len
(
req_to_blocks
)
==
0
# Process some request with length exceeding the sliding window
outputs
=
llm
.
generate
([
"hi"
*
1401
],
sampling_params
)
kv_params
=
outputs
[
0
].
kv_transfer_params
# +1 to account for overlapping window across blocks.
expected_num_remote_blocks
=
sw_size
//
block_size
+
1
remote_block_ids
=
kv_params
[
"remote_block_ids"
]
assert
(
len
(
remote_block_ids
[
0
])
==
expected_num_remote_blocks
<
len
(
remote_block_ids
[
-
1
])
)
for
group_block_ids
in
remote_block_ids
[:
-
1
]:
assert
len
(
group_block_ids
)
==
expected_num_remote_blocks
def
run_test_and_cleanup
():
llm
=
LLM
(
**
llm_kwargs
)
try
:
run_hma_test
(
llm
)
finally
:
llm
.
llm_engine
.
engine_core
.
shutdown
()
run_test_and_cleanup
()
@
pytest
.
mark
.
cpu_test
def
test_nixl_metadata_hma_block_ids_structure
():
"""
Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
,
)
metadata
=
NixlConnectorMetadata
()
# Add request with block IDs for 2 groups (FA + SW)
fa_blocks
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
# 8 blocks for FA
sw_blocks
=
[
8
,
9
,
10
,
11
]
# 4 blocks for SW (clipped)
metadata
.
add_new_req_to_recv
(
request_id
=
"test-req-hma"
,
local_block_ids
=
(
fa_blocks
,
sw_blocks
),
kv_transfer_params
=
{
"remote_block_ids"
:
([
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
],
[
18
,
19
,
20
,
21
]),
"remote_engine_id"
:
"remote-engine"
,
"remote_request_id"
:
"prefill-test-req-hma"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"tp_size"
:
1
,
},
)
assert
"test-req-hma"
in
metadata
.
reqs_to_recv
req_meta
=
metadata
.
reqs_to_recv
[
"test-req-hma"
]
# Verify local block IDs structure
assert
len
(
req_meta
.
local_block_ids
)
==
2
assert
list
(
req_meta
.
local_block_ids
[
0
])
==
fa_blocks
assert
list
(
req_meta
.
local_block_ids
[
1
])
==
sw_blocks
# Verify remote block IDs structure
assert
req_meta
.
remote
is
not
None
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
5b3ba94a
...
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks
=
sum
(
len
(
g
)
for
g
in
kv_transfer_params
[
"remote_block_ids"
])
assert
num_remote_blocks
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# STEP (2): Ensure it is freed.
scheduler_output
=
scheduler
.
schedule
()
...
...
tests/v1/kv_connector/unit/utils.py
View file @
5b3ba94a
...
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
,
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
...
...
@@ -142,9 +143,11 @@ def create_vllm_config(
def
create_scheduler
(
vllm_config
:
VllmConfig
,
num_blocks
:
int
=
10000
,
kv_cache_config
:
KVCacheConfig
|
None
=
None
,
)
->
Scheduler
:
"""Initialize Scheduler For Testing."""
block_size
=
vllm_config
.
cache_config
.
block_size
if
kv_cache_config
is
None
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_tensors
=
[],
...
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory
.
register_connector
(
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
)
def
make_kv_cache_config
(
block_size
:
int
,
hma_enabled
:
bool
=
False
,
sw_size
:
int
=
128
,
num_blocks
:
int
=
100
,
)
->
KVCacheConfig
:
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer0"
,
"layer2"
],
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
),
)
]
if
hma_enabled
:
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
[
"layer1"
,
"layer3"
],
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
sliding_window
=
sw_size
,
),
)
)
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
5b3ba94a
...
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
EngineId
=
str
# block ids as returned by the hybrid KV cache manager. list[list[int]] are allow
# mutability and are for connector internal use only.
BlockIds
=
tuple
[
list
[
int
],
...]
|
list
[
list
[
int
]]
def
get_kv_connector_cache_layout
():
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
5b3ba94a
This diff is collapsed.
Click to expand it.
vllm/v1/core/kv_cache_manager.py
View file @
5b3ba94a
...
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
def
get_unhashed_block_ids_all_groups
(
self
)
->
list
[
list
[
int
]]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
# Skip padding blocks.
return
[
[
block
.
block_id
for
block
in
group
if
block
.
block_hash
is
None
and
not
block
.
is_null
]
for
group
in
self
.
blocks
]
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""
Creates a new KVCacheBlocks instance with no blocks.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment