Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b3ba94a
Unverified
Commit
5b3ba94a
authored
Mar 06, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 06, 2026
Browse files
[Core][KVConnector] Support HMA+NixlConnector (#35758)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
90f3c01f
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
669 additions
and
230 deletions
+669
-230
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+9
-0
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+36
-1
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+122
-46
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+203
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+3
-1
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+53
-15
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+3
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+227
-167
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
5b3ba94a
...
@@ -12,6 +12,7 @@ tp_configs=(
...
@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
# SW model
)
)
dp_ep_configs
=(
dp_ep_configs
=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
...
@@ -26,6 +27,14 @@ else
...
@@ -26,6 +27,14 @@ else
configs
=(
"
${
tp_configs
[@]
}
"
)
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
fi
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo
"ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for
i
in
"
${
!configs[@]
}
"
;
do
configs[
$i
]=
"ENABLE_HMA_FLAG=1
${
configs
[
$i
]
}
"
done
fi
run_tests
()
{
run_tests
()
{
local
label
=
$1
local
label
=
$1
local
extra_args
=
$2
local
extra_args
=
$2
...
...
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
5b3ba94a
...
@@ -5,6 +5,12 @@ set -xe
...
@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS
=
"False"
CROSS_LAYERS_BLOCKS
=
"False"
ENABLE_HMA_VAR
=
""
# Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
ENABLE_HMA_VAR
=
"--no-disable-hybrid-kv-cache-manager"
fi
while
[[
$#
-gt
0
]]
;
do
while
[[
$#
-gt
0
]]
;
do
case
$1
in
case
$1
in
--kv_buffer_device
)
--kv_buffer_device
)
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
echo
"Using attention backend:
$ATTENTION_BACKEND
"
echo
"Using attention backend:
$ATTENTION_BACKEND
"
fi
fi
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
echo
"HMA (Hybrid KV Cache Manager) enabled"
fi
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
echo
"vLLM serve extra args:
$VLLM_SERVE_EXTRA_ARGS
"
fi
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS
=
${
VLLM_SERVE_EXTRA_ARGS
:-}
# Find the git repository root directory
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
FULL_CMD
=
"
$BASE_CMD
"
FULL_CMD
=
"
$BASE_CMD
"
eval
"
$FULL_CMD
&"
eval
"
$FULL_CMD
&"
# Store host and port for proxy configuration
# Store host and port for proxy configuration
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size
${
DECODE_BLOCK_SIZE
}
\
--block-size
${
DECODE_BLOCK_SIZE
}
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
# DP-EP attention mode
# DP-EP attention mode
if
[[
-z
"
$DP_EP
"
]]
;
then
if
[[
-z
"
$DP_EP
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
5b3ba94a
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
}
}
SIMPLE_PROMPT
=
(
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
5b3ba94a
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
(
create_request
,
create_scheduler
,
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
...
@@ -263,7 +268,7 @@ def test_basic_interface():
...
@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
for
block_id
,
block
in
zip
(
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
req_meta
.
local_block_ids
[
0
]
,
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
request_id
request_id
],
],
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
)
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode
=
True
,
do_remote_decode
=
True
,
)
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
delay
,
kv_connector_metadata
=
(
request
,
[
0
,
1
,
2
]
scheduler
.
get_kv_connector
().
request_finished_all_groups
(
request
,
([
0
,
1
,
2
],)
)
)
)
assert
delay
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_connector
.
register_kv_caches
(
kv_caches
)
decode_connector
.
register_kv_caches
(
kv_caches
)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Here we are testing the retrieval of NIXLAgentMetadata.
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID
=
"remote_engine"
REMOTE_ENGINE_ID
=
"remote_engine"
def
__init__
(
def
__init__
(
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
**
kwargs
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
kv_cache_config
=
None
,
**
kwargs
,
):
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
kv_cache_config
is
None
:
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
)
super
().
__init__
(
*
args
,
kv_cache_config
=
kv_cache_config
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
# Mock register_kv_caches attribute needed for tests that do not call it.
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id
=
"req_id"
request_id
=
"req_id"
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers
-=
1
num_xfers
-=
1
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
local_block_ids
=
(
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
"remote_block_ids"
:
(
num_xfers
+
4
,
[
num_xfers
+
5
,
num_xfers
+
4
,
num_xfers
+
6
,
num_xfers
+
5
,
],
num_xfers
+
6
,
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
vllm_config
,
connector
.
engine_id
)
)
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
"id"
,
request_id
=
"id"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
"prefill-id"
,
"remote_request_id"
:
"prefill-id"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size
=
1
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size
=
2
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
=
NixlConnector
(
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
vllm_config
,
connector
.
engine_id
)
)
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
for
i
in
range
(
total_reqs
):
for
i
in
range
(
total_reqs
):
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
f
"id_
{
i
}
"
,
request_id
=
f
"id_
{
i
}
"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value
=
2
,
return_value
=
2
,
):
):
# Initialize connector and worker (with fake NIXL wrapper)
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value
=
2
,
return_value
=
2
,
):
):
# Initialize connector and worker (with fake NIXL wrapper)
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
vllm_config
,
connector
.
engine_id
,
connector
.
engine_id
,
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backend
.
return_value
=
backend_cls
# Create connector
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
),
):
# noqa: E501
):
# noqa: E501
# Create connector and replace its worker with a fake one for isolation
# Create connector and replace its worker with a fake one for isolation
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
# Verify get_reg_descs was called with the correct memory_type
# Verify get_reg_descs was called with the correct memory_type
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
scheduler
=
NixlConnectorScheduler
(
scheduler
=
NixlConnectorScheduler
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
)
nixl_wrapper
=
worker
.
nixl_wrapper
nixl_wrapper
=
worker
.
nixl_wrapper
with
(
with
(
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler
=
create_scheduler
(
vllm_config
)
scheduler
=
create_scheduler
(
vllm_config
)
# KVConnector Worker in P
# KVConnector Worker in P
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"enable_hma"
,
[
False
,
True
])
def
test_transfer_failure_logging
(
def
test_transfer_failure_logging
(
default_vllm_config
,
default_vllm_config
,
dist_init
,
dist_init
,
failure_type
,
failure_type
,
wrapper_config
,
wrapper_config
,
needs_get_finished
,
needs_get_finished
,
enable_hma
,
):
):
"""Test that transfer failures are logged with structured context.
"""Test that transfer failures are logged with structured context.
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
enable_hma
),
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
,
kv_cache_config
=
connector
.
_kv_cache_config
,
)
)
# Configure FailingNixlWrapper to fail in the specified way
# Configure FailingNixlWrapper to fail in the specified way
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
# (full cache hit path to trigger send_notif)
local_blocks
=
[]
if
failure_type
==
"notification_failed"
else
[
10
,
11
,
12
]
local_blocks
:
tuple
[()]
|
tuple
[
list
[
int
],
...]
remote_blocks
=
[
20
,
21
,
22
]
if
enable_hma
:
# HMA enabled: multiple groups (FA + SW)
local_blocks
=
(
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],
[
13
,
14
])
)
remote_blocks
=
[[
20
,
21
,
22
],
[
23
,
24
]]
else
:
# HMA disabled: single group
local_blocks
=
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],)
remote_blocks
=
[[
20
,
21
,
22
]]
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
)
)
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
and return via get_finished."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
7
,
8
,
9
],
local_block_ids
=
(
[
7
,
8
,
9
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_block_ids"
:
(
[
10
,
11
,
12
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat"
:
enforce_handshake_compat
"enforce_handshake_compat"
:
enforce_handshake_compat
},
},
)
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
decode_worker
=
decode_connector
.
connector_worker
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model
=
"facebook/opt-125m"
,
model
=
"facebook/opt-125m"
,
block_size
=
16
,
block_size
=
16
,
)
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
decode_worker
=
decode_connector
.
connector_worker
backend
=
get_current_attn_backend
(
local_vllm_config
)
backend
=
get_current_attn_backend
(
local_vllm_config
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
0 → 100644
View file @
5b3ba94a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from
unittest.mock
import
patch
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SlidingWindowManager
,
)
from
.utils
import
(
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"hma_enabled,expected_sw_sizes"
,
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(
True
,
[
0
,
128
+
1
]),
# HMA disabled: only FullAttentionSpec (0)
(
False
,
[
0
]),
],
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_sw_sizes
(
mock_platform
,
hma_enabled
,
expected_sw_sizes
):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
)
mock_platform
.
device_type
=
"cpu"
block_size
=
16
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
# SW 2048 tokens=>128 blocks
kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
hma_enabled
=
hma_enabled
,
sw_size
=
2048
)
scheduler
=
NixlConnectorScheduler
(
vllm_config
=
vllm_config
,
engine_id
=
"test-engine"
,
kv_cache_config
=
kv_cache_config
,
)
# in number of blocks
assert
scheduler
.
blocks_per_sw
==
expected_sw_sizes
,
(
f
"Expected sw_sizes=
{
expected_sw_sizes
}
, got
{
scheduler
.
blocks_per_sw
}
"
)
@
pytest
.
mark
.
cpu_test
def
test_logical_to_kernel_block_ids_with_hma
():
"""Test _logical_to_kernel_block_ids expands blocks when HMA is enabled.
When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
# Create a mock worker with just the required attributes
# (use __new__ to skip __init__)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
kernel_block_ids
=
worker
.
_logical_to_kernel_block_ids
(
logical_block_ids
)
expected_kernel_block_ids
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
6
,
7
,
8
,
9
]]
assert
kernel_block_ids
==
expected_kernel_block_ids
,
(
f
"Expected
{
expected_kernel_block_ids
}
, got
{
kernel_block_ids
}
"
)
@
pytest
.
mark
.
parametrize
(
"model_name, sw_size"
,
[(
"google/gemma-3-1b-it"
,
512
)])
def
test_fewer_blocks_with_hma
(
monkeypatch
,
model_name
,
sw_size
):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
when sequence exceeds the sliding window.
"""
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
)
block_size
=
16
llm_kwargs
=
{
"model"
:
model_name
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.5
,
"kv_transfer_config"
:
kv_transfer_config
,
"max_model_len"
:
2048
,
# NOTE: Make sure HMA is enabled
"disable_hybrid_kv_cache_manager"
:
False
,
"max_num_batched_tokens"
:
1024
,
"enable_prefix_caching"
:
False
,
"block_size"
:
block_size
,
}
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
def
run_hma_test
(
llm
:
LLM
):
remote_prefill_opts
=
{
"do_remote_decode"
:
True
,
"do_remote_prefill"
:
False
,
"remote_engine_id"
:
None
,
"remote_block_ids"
:
None
,
"remote_host"
:
None
,
"remote_port"
:
None
,
}
# Simulate sidecar request
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
1
,
extra_args
=
{
"kv_transfer_params"
:
remote_prefill_opts
},
)
scheduler
=
llm
.
llm_engine
.
engine_core
.
engine_core
.
scheduler
kv_managers
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
# HMA enabled with FA + SWA groups
assert
len
(
kv_managers
)
>
2
for
kv_manager
in
kv_managers
:
assert
isinstance
(
kv_manager
,
(
SlidingWindowManager
,
FullAttentionManager
))
req_to_blocks
=
kv_managers
[
0
].
req_to_blocks
assert
len
(
req_to_blocks
)
==
0
# Process some request with length exceeding the sliding window
outputs
=
llm
.
generate
([
"hi"
*
1401
],
sampling_params
)
kv_params
=
outputs
[
0
].
kv_transfer_params
# +1 to account for overlapping window across blocks.
expected_num_remote_blocks
=
sw_size
//
block_size
+
1
remote_block_ids
=
kv_params
[
"remote_block_ids"
]
assert
(
len
(
remote_block_ids
[
0
])
==
expected_num_remote_blocks
<
len
(
remote_block_ids
[
-
1
])
)
for
group_block_ids
in
remote_block_ids
[:
-
1
]:
assert
len
(
group_block_ids
)
==
expected_num_remote_blocks
def
run_test_and_cleanup
():
llm
=
LLM
(
**
llm_kwargs
)
try
:
run_hma_test
(
llm
)
finally
:
llm
.
llm_engine
.
engine_core
.
shutdown
()
run_test_and_cleanup
()
@
pytest
.
mark
.
cpu_test
def
test_nixl_metadata_hma_block_ids_structure
():
"""
Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
,
)
metadata
=
NixlConnectorMetadata
()
# Add request with block IDs for 2 groups (FA + SW)
fa_blocks
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
# 8 blocks for FA
sw_blocks
=
[
8
,
9
,
10
,
11
]
# 4 blocks for SW (clipped)
metadata
.
add_new_req_to_recv
(
request_id
=
"test-req-hma"
,
local_block_ids
=
(
fa_blocks
,
sw_blocks
),
kv_transfer_params
=
{
"remote_block_ids"
:
([
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
],
[
18
,
19
,
20
,
21
]),
"remote_engine_id"
:
"remote-engine"
,
"remote_request_id"
:
"prefill-test-req-hma"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"tp_size"
:
1
,
},
)
assert
"test-req-hma"
in
metadata
.
reqs_to_recv
req_meta
=
metadata
.
reqs_to_recv
[
"test-req-hma"
]
# Verify local block IDs structure
assert
len
(
req_meta
.
local_block_ids
)
==
2
assert
list
(
req_meta
.
local_block_ids
[
0
])
==
fa_blocks
assert
list
(
req_meta
.
local_block_ids
[
1
])
==
sw_blocks
# Verify remote block IDs structure
assert
req_meta
.
remote
is
not
None
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
5b3ba94a
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
# even if there is a cache hit.
assert
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks
=
sum
(
len
(
g
)
for
g
in
kv_transfer_params
[
"remote_block_ids"
])
assert
num_remote_blocks
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# STEP (2): Ensure it is freed.
# STEP (2): Ensure it is freed.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
...
...
tests/v1/kv_connector/unit/utils.py
View file @
5b3ba94a
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheGroupSpec
,
SlidingWindowSpec
,
)
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -142,24 +143,26 @@ def create_vllm_config(
...
@@ -142,24 +143,26 @@ def create_vllm_config(
def
create_scheduler
(
def
create_scheduler
(
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
num_blocks
:
int
=
10000
,
num_blocks
:
int
=
10000
,
kv_cache_config
:
KVCacheConfig
|
None
=
None
,
)
->
Scheduler
:
)
->
Scheduler
:
"""Initialize Scheduler For Testing."""
"""Initialize Scheduler For Testing."""
block_size
=
vllm_config
.
cache_config
.
block_size
block_size
=
vllm_config
.
cache_config
.
block_size
kv_cache_config
=
KVCacheConfig
(
if
kv_cache_config
is
None
:
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_config
=
KVCacheConfig
(
kv_cache_tensors
=
[],
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_groups
=
[
kv_cache_tensors
=
[],
KVCacheGroupSpec
(
kv_cache_groups
=
[
[
"layer"
],
KVCacheGroupSpec
(
FullAttentionSpec
(
[
"layer"
],
block_size
=
block_size
,
FullAttentionSpec
(
num_kv_heads
=
1
,
block_size
=
block_size
,
head_size
=
1
,
num_kv_heads
=
1
,
dtype
=
torch
.
float32
,
head_size
=
1
,
),
dtype
=
torch
.
float32
,
)
),
],
)
)
],
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_blocks
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_blocks
return
Scheduler
(
return
Scheduler
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory
.
register_connector
(
KVConnectorFactory
.
register_connector
(
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
)
)
def
make_kv_cache_config
(
block_size
:
int
,
hma_enabled
:
bool
=
False
,
sw_size
:
int
=
128
,
num_blocks
:
int
=
100
,
)
->
KVCacheConfig
:
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer0"
,
"layer2"
],
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
),
)
]
if
hma_enabled
:
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
[
"layer1"
,
"layer3"
],
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
sliding_window
=
sw_size
,
),
)
)
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
5b3ba94a
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
EngineId
=
str
EngineId
=
str
# block ids as returned by the hybrid KV cache manager. list[list[int]] are allow
# mutability and are for connector internal use only.
BlockIds
=
tuple
[
list
[
int
],
...]
|
list
[
list
[
int
]]
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
5b3ba94a
This diff is collapsed.
Click to expand it.
vllm/v1/core/kv_cache_manager.py
View file @
5b3ba94a
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
def
get_unhashed_block_ids_all_groups
(
self
)
->
list
[
list
[
int
]]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
# Skip padding blocks.
return
[
[
block
.
block_id
for
block
in
group
if
block
.
block_hash
is
None
and
not
block
.
is_null
]
for
group
in
self
.
blocks
]
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""
"""
Creates a new KVCacheBlocks instance with no blocks.
Creates a new KVCacheBlocks instance with no blocks.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment