Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b3ba94a
Unverified
Commit
5b3ba94a
authored
Mar 06, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 06, 2026
Browse files
[Core][KVConnector] Support HMA+NixlConnector (#35758)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
90f3c01f
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
669 additions
and
230 deletions
+669
-230
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+9
-0
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+36
-1
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+122
-46
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+203
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+3
-1
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+53
-15
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+3
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+227
-167
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
5b3ba94a
...
...
@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
# SW model
)
dp_ep_configs
=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
...
...
@@ -26,6 +27,14 @@ else
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
# Prepend ENABLE_HMA_FLAG=1 to each config in the selected array
echo
"ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for
i
in
"
${
!configs[@]
}
"
;
do
configs[
$i
]=
"ENABLE_HMA_FLAG=1
${
configs
[
$i
]
}
"
done
fi
run_tests
()
{
local
label
=
$1
local
extra_args
=
$2
...
...
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
5b3ba94a
...
...
@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS
=
"False"
ENABLE_HMA_VAR
=
""
# Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
ENABLE_HMA_VAR
=
"--no-disable-hybrid-kv-cache-manager"
fi
while
[[
$#
-gt
0
]]
;
do
case
$1
in
--kv_buffer_device
)
...
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
echo
"Using attention backend:
$ATTENTION_BACKEND
"
fi
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
echo
"HMA (Hybrid KV Cache Manager) enabled"
fi
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
echo
"vLLM serve extra args:
$VLLM_SERVE_EXTRA_ARGS
"
fi
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
...
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS
=
${
VLLM_SERVE_EXTRA_ARGS
:-}
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
FULL_CMD
=
"
$BASE_CMD
"
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
FULL_CMD
=
"
$BASE_CMD
"
eval
"
$FULL_CMD
&"
# Store host and port for proxy configuration
...
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size
${
DECODE_BLOCK_SIZE
}
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
# DP-EP attention mode
if
[[
-z
"
$DP_EP
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
5b3ba94a
...
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
}
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
5b3ba94a
...
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
(
create_request
,
create_scheduler
,
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
...
...
@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
req_meta
.
local_block_ids
[
0
]
,
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
request_id
],
...
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
...
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode
=
True
,
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
request
,
[
0
,
1
,
2
]
delay
,
kv_connector_metadata
=
(
scheduler
.
get_kv_connector
().
request_finished_all_groups
(
request
,
([
0
,
1
,
2
],)
)
)
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_connector
.
register_kv_caches
(
kv_caches
)
# Here we are testing the retrieval of NIXLAgentMetadata.
...
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID
=
"remote_engine"
def
__init__
(
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
**
kwargs
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
kv_cache_config
=
None
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
kv_cache_config
is
None
:
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
)
super
().
__init__
(
*
args
,
kv_cache_config
=
kv_cache_config
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
...
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id
=
"req_id"
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers
-=
1
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
local_block_ids
=
(
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
"remote_block_ids"
:
(
[
num_xfers
+
4
,
num_xfers
+
5
,
num_xfers
+
6
,
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
"id"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
"prefill-id"
,
"remote_host"
:
"localhost"
,
...
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
...
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
for
i
in
range
(
total_reqs
):
metadata
.
add_new_req_to_recv
(
request_id
=
f
"id_
{
i
}
"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value
=
2
,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value
=
2
,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
...
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend
.
return_value
=
backend_cls
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
):
# noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
# Verify get_reg_descs was called with the correct memory_type
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
...
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
scheduler
=
NixlConnectorScheduler
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
)
nixl_wrapper
=
worker
.
nixl_wrapper
with
(
...
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler
=
create_scheduler
(
vllm_config
)
# KVConnector Worker in P
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
],
)
@
pytest
.
mark
.
parametrize
(
"enable_hma"
,
[
False
,
True
])
def
test_transfer_failure_logging
(
default_vllm_config
,
dist_init
,
failure_type
,
wrapper_config
,
needs_get_finished
,
enable_hma
,
):
"""Test that transfer failures are logged with structured context.
...
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
enable_hma
),
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
,
kv_cache_config
=
connector
.
_kv_cache_config
,
)
# Configure FailingNixlWrapper to fail in the specified way
...
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
local_blocks
=
[]
if
failure_type
==
"notification_failed"
else
[
10
,
11
,
12
]
remote_blocks
=
[
20
,
21
,
22
]
local_blocks
:
tuple
[()]
|
tuple
[
list
[
int
],
...]
if
enable_hma
:
# HMA enabled: multiple groups (FA + SW)
local_blocks
=
(
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],
[
13
,
14
])
)
remote_blocks
=
[[
20
,
21
,
22
],
[
23
,
24
]]
else
:
# HMA disabled: single group
local_blocks
=
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],)
remote_blocks
=
[[
20
,
21
,
22
]]
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
...
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
)
...
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
...
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
local_block_ids
=
[
7
,
8
,
9
],
local_block_ids
=
(
[
7
,
8
,
9
],
),
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_block_ids"
:
(
[
10
,
11
,
12
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
...
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat"
:
enforce_handshake_compat
},
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
...
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model
=
"facebook/opt-125m"
,
block_size
=
16
,
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
backend
=
get_current_attn_backend
(
local_vllm_config
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
0 → 100644
View file @
5b3ba94a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from
unittest.mock
import
patch
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SlidingWindowManager
,
)
from
.utils
import
(
create_vllm_config
,
make_kv_cache_config
,
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "hma_enabled,expected_sw_sizes",
    [
        # With HMA: a full-attention group (0) followed by a sliding-window
        # group. 2048 tokens / 16 tokens-per-block = 128 blocks, and the
        # scheduler is expected to report one extra block on top of that.
        (True, [0, 128 + 1]),
        # Without HMA: only the full-attention group (0).
        (False, [0]),
    ],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
    """Check ``NixlConnectorScheduler.blocks_per_sw`` with HMA on and off.

    The scheduler is built from a KV cache config produced by
    ``make_kv_cache_config`` using a 2048-token sliding window and 16-token
    blocks; the resulting ``blocks_per_sw`` (expressed in blocks) must equal
    the parametrized expectation.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    # Make the patched platform report a CPU device.
    mock_platform.device_type = "cpu"

    tokens_per_block = 16
    config = create_vllm_config(block_size=tokens_per_block)
    # 2048-token window => 128 blocks of 16 tokens.
    cache_config = make_kv_cache_config(
        block_size=tokens_per_block,
        hma_enabled=hma_enabled,
        sw_size=2048,
    )

    scheduler = NixlConnectorScheduler(
        vllm_config=config,
        engine_id="test-engine",
        kv_cache_config=cache_config,
    )

    # Sizes are compared in number of blocks, not tokens.
    assert scheduler.blocks_per_sw == expected_sw_sizes, (
        f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}"
    )
@pytest.mark.cpu_test
def test_logical_to_kernel_block_ids_with_hma():
    """Check that ``_logical_to_kernel_block_ids`` expands every group.

    With HMA the logical block size can be larger than the kernel block
    size, in which case one logical block corresponds to several kernel
    blocks. Here the ratio is fixed to 2 and the conversion is applied to
    two groups of block IDs (FA + SW).
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    # Allocate the worker without running __init__; only the attribute the
    # conversion reads is populated below.
    bare_worker = object.__new__(NixlConnectorWorker)

    # Logical block size 32 vs. kernel block size 16: each logical block
    # covers two kernel blocks, e.g. [0] -> [0, 1].
    bare_worker._physical_blocks_per_logical_kv_block = 2

    # One list per KV cache group: FA first, then SW.
    per_group_logical_ids = [[0, 1, 2], [3, 4]]
    kernel_block_ids = bare_worker._logical_to_kernel_block_ids(
        per_group_logical_ids
    )

    expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]
    assert kernel_block_ids == expected_kernel_block_ids, (
        f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}"
    )
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
    """Check that sliding-window groups report fewer remote blocks.

    A prefill instance is run with HMA enabled on a prompt longer than the
    sliding window. The ``remote_block_ids`` returned in the KV transfer
    params must hold ``sw_size // block_size + 1`` blocks for every group but
    the last, and the last group must hold strictly more than that.
    """
    block_size = 16
    transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    llm_kwargs = dict(
        model=model_name,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=transfer_config,
        max_model_len=2048,
        # HMA must stay enabled for this test to be meaningful.
        disable_hybrid_kv_cache_manager=False,
        max_num_batched_tokens=1024,
        enable_prefix_caching=False,
        block_size=block_size,
    )

    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    def check_remote_blocks(llm: LLM):
        """Generate one token and validate the per-group remote block IDs."""
        # Emulate the request a sidecar would send to the prefill instance.
        prefill_request_opts = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }
        params = SamplingParams(
            temperature=0.0,
            max_tokens=1,
            extra_args={"kv_transfer_params": prefill_request_opts},
        )

        engine_core = llm.llm_engine.engine_core.engine_core
        managers = (
            engine_core.scheduler.kv_cache_manager.coordinator.single_type_managers
        )
        # HMA enabled: several groups, each either FA or SWA.
        assert len(managers) > 2
        allowed_manager_types = (SlidingWindowManager, FullAttentionManager)
        for manager in managers:
            assert isinstance(manager, allowed_manager_types)

        # No request has been scheduled yet.
        tracked_blocks = managers[0].req_to_blocks
        assert len(tracked_blocks) == 0

        # Prompt length is chosen to exceed the sliding window.
        results = llm.generate(["hi" * 1401], params)
        transfer_params = results[0].kv_transfer_params

        # One extra block accounts for the window overlapping block borders.
        expected_num_remote_blocks = sw_size // block_size + 1
        remote_block_ids = transfer_params["remote_block_ids"]
        first_group_len = len(remote_block_ids[0])
        assert (
            first_group_len == expected_num_remote_blocks
            and expected_num_remote_blocks < len(remote_block_ids[-1])
        )
        for group_ids in remote_block_ids[:-1]:
            assert len(group_ids) == expected_num_remote_blocks

    # Always shut the engine core down, even when an assertion fails.
    llm = LLM(**llm_kwargs)
    try:
        check_remote_blocks(llm)
    finally:
        llm.llm_engine.engine_core.shutdown()
@pytest.mark.cpu_test
def test_nixl_metadata_hma_block_ids_structure():
    """Check per-group block IDs are kept intact by ``NixlConnectorMetadata``.

    A request is registered for receive with two groups of local and remote
    block IDs (FA + SW); the stored request metadata must expose both groups,
    in order, for the local and the remote side.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    request_id = "test-req-hma"
    fa_blocks = list(range(8))  # 8 blocks for FA
    sw_blocks = list(range(8, 12))  # 4 blocks for SW (clipped)

    transfer_params = {
        "remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21]),
        "remote_engine_id": "remote-engine",
        "remote_request_id": "prefill-test-req-hma",
        "remote_host": "localhost",
        "remote_port": 1234,
        "tp_size": 1,
    }

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=(fa_blocks, sw_blocks),
        kv_transfer_params=transfer_params,
    )

    assert request_id in metadata.reqs_to_recv
    stored = metadata.reqs_to_recv[request_id]

    # Local side: two groups, FA then SW.
    local_groups = stored.local_block_ids
    assert len(local_groups) == 2
    assert list(local_groups[0]) == fa_blocks
    assert list(local_groups[1]) == sw_blocks

    # Remote side: same two-group layout.
    assert stored.remote is not None
    remote_groups = stored.remote.block_ids
    assert len(remote_groups) == 2
    assert list(remote_groups[0]) == list(range(10, 18))
    assert list(remote_groups[1]) == list(range(18, 22))
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
5b3ba94a
...
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks
=
sum
(
len
(
g
)
for
g
in
kv_transfer_params
[
"remote_block_ids"
])
assert
num_remote_blocks
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# STEP (2): Ensure it is freed.
scheduler_output
=
scheduler
.
schedule
()
...
...
tests/v1/kv_connector/unit/utils.py
View file @
5b3ba94a
...
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
,
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
...
...
@@ -142,9 +143,11 @@ def create_vllm_config(
def
create_scheduler
(
vllm_config
:
VllmConfig
,
num_blocks
:
int
=
10000
,
kv_cache_config
:
KVCacheConfig
|
None
=
None
,
)
->
Scheduler
:
"""Initialize Scheduler For Testing."""
block_size
=
vllm_config
.
cache_config
.
block_size
if
kv_cache_config
is
None
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_tensors
=
[],
...
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory
.
register_connector
(
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
)
def make_kv_cache_config(
    block_size: int,
    hma_enabled: bool = False,
    sw_size: int = 128,
    num_blocks: int = 100,
) -> KVCacheConfig:
    """Build a small ``KVCacheConfig`` for connector unit tests.

    Args:
        block_size: Block size passed to every attention spec.
        hma_enabled: When true, a sliding-window group is added after the
            full-attention group.
        sw_size: ``sliding_window`` value of the sliding-window spec; only
            used when ``hma_enabled`` is true.
        num_blocks: ``num_blocks`` of the returned config.

    Returns:
        A config with no KV cache tensors and one group (full attention on
        "layer0"/"layer2"), or two groups when ``hma_enabled`` is set (the
        second being sliding-window attention on "layer1"/"layer3").
    """
    # Fields shared by both spec types.
    common_spec_kwargs = dict(
        block_size=block_size,
        num_kv_heads=4,
        head_size=16,
        dtype=torch.float16,
    )

    full_attention_group = KVCacheGroupSpec(
        ["layer0", "layer2"],
        FullAttentionSpec(**common_spec_kwargs),
    )
    groups = [full_attention_group]

    if hma_enabled:
        sliding_window_group = KVCacheGroupSpec(
            ["layer1", "layer3"],
            SlidingWindowSpec(**common_spec_kwargs, sliding_window=sw_size),
        )
        groups.append(sliding_window_group)

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
5b3ba94a
...
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
EngineId
=
str
# Block IDs as returned by the hybrid KV cache manager. The list[list[int]] form
# allows mutability and is for connector-internal use only.
BlockIds
=
tuple
[
list
[
int
],
...]
|
list
[
list
[
int
]]
def
get_kv_connector_cache_layout
():
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
5b3ba94a
...
...
@@ -3,7 +3,6 @@
import
contextlib
import
copy
import
logging
import
math
import
os
import
queue
import
sys
...
...
@@ -24,6 +23,7 @@ import zmq
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
BlockIds
,
EngineId
,
TpKVTopology
,
get_current_attn_backend
,
...
...
@@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata
,
KVConnectorMetadata
,
KVConnectorRole
,
SupportsHMA
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
...
...
@@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
...
...
@@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash(
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
is_hma_enabled
=
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
factors
=
{
# Version compatibility
...
...
@@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash(
"attn_backend_name"
:
attn_backend_name
,
"cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
"cross_layers_blocks"
:
cross_layers_blocks
,
"is_hma_enabled"
:
is_hma_enabled
,
}
compat_hash
=
hash_factors
(
factors
)
...
...
@@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash(
@
dataclass
class
RemoteMeta
:
block_ids
:
list
[
int
]
block_ids
:
BlockIds
host
:
str
port
:
int
engine_id
:
str
...
...
@@ -247,9 +252,9 @@ class RemoteMeta:
@
dataclass
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
BlockIds
# To be used when logical block size does not match the kernel block size
local_physical_block_ids
:
list
[
int
]
local_physical_block_ids
:
BlockIds
tp_size
:
int
remote
:
RemoteMeta
|
None
=
None
...
...
@@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
_add_new_req
(
self
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
)
->
ReqMeta
:
return
ReqMeta
(
...
...
@@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
add_new_req_to_save
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
):
self
.
reqs_to_save
[
request_id
]
=
self
.
_add_new_req
(
...
...
@@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
add_new_req_to_recv
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
):
req
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
...
...
@@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
.
reqs_to_recv
[
request_id
]
=
req
class
NixlConnector
(
KVConnectorBase_V1
):
class
NixlConnector
(
KVConnectorBase_V1
,
SupportsHMA
):
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
backend
=
get_current_attn_backend
(
self
.
_vllm_config
)
...
...
@@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1):
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
,
kv_cache_config
:
"KVCacheConfig
| None"
=
None
,
kv_cache_config
:
"KVCacheConfig
"
,
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
for
group
in
kv_cache_config
.
kv_cache_groups
:
if
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
):
raise
ValueError
(
"NixlConnector does not support Mamba models."
)
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
connector_scheduler
:
NixlConnectorScheduler
|
None
=
(
NixlConnectorScheduler
(
vllm_config
,
self
.
engine_id
)
NixlConnectorScheduler
(
vllm_config
,
self
.
engine_id
,
kv_cache_config
)
)
self
.
connector_worker
:
NixlConnectorWorker
|
None
=
None
elif
role
==
KVConnectorRole
.
WORKER
:
self
.
connector_scheduler
=
None
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
self
.
engine_id
)
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
self
.
engine_id
,
kv_cache_config
)
############################################################
# Class Methods
...
...
@@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1):
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
build_connector_meta
(
scheduler_output
)
def
request_finished
(
def
request_finished
_all_groups
(
self
,
request
:
"Request"
,
block_ids
:
list
[
int
],
block_ids
:
tuple
[
list
[
int
],
...],
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
request_finished
(
request
,
block_ids
)
...
...
@@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1):
class
NixlConnectorScheduler
:
"""Implementation of Scheduler side methods"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
,
kv_cache_config
:
"KVCacheConfig"
):
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
engine_id
:
EngineId
=
engine_id
self
.
kv_cache_config
=
kv_cache_config
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_port
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
...
...
@@ -534,8 +547,18 @@ class NixlConnectorScheduler:
self
.
use_host_buffer
=
(
vllm_config
.
kv_transfer_config
.
kv_buffer_device
==
"cpu"
)
self
.
_is_hma_required
=
(
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
# Also handle unlikely SW-only model case instead of checking num_groups>1.
and
any
(
not
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
if
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
logger
.
info
(
"Hybrid Memory Allocator is enabled with NIXL"
)
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
...
...
@@ -545,7 +568,7 @@ class NixlConnectorScheduler:
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]
]]
=
{}
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
BlockIds
]]
=
{}
self
.
_reqs_need_save
:
dict
[
ReqId
,
Request
]
=
{}
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
...
...
@@ -554,12 +577,54 @@ class NixlConnectorScheduler:
# remote prefill or aborted.
self
.
_reqs_not_processed
:
set
[
ReqId
]
=
set
()
# Gather the sliding window size (in tokens) and the block size for each KV cache
# group; non-SWA groups get a window of 0. These are converted below into a number
# of blocks per group, which is used to clip the local attention window.
sw_sizes_tokens
:
list
[
tuple
[
int
,
int
]]
=
[
(
g
.
kv_cache_spec
.
sliding_window
,
g
.
kv_cache_spec
.
block_size
)
if
isinstance
(
g
.
kv_cache_spec
,
SlidingWindowSpec
)
else
(
0
,
self
.
block_size
)
for
g
in
kv_cache_config
.
kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
# account for boundary overlap, e.g. when the window isn't fully aligned with blocks.
self
.
blocks_per_sw
=
[
cdiv
(
n_tokens
,
block_size
)
+
1
if
n_tokens
else
0
for
n_tokens
,
block_size
in
sw_sizes_tokens
]
def
shutdown
(
self
):
self
.
_stop_event
.
set
()
if
self
.
_nixl_handshake_listener_t
is
not
None
:
self
.
_nixl_handshake_listener_t
.
join
()
self
.
_nixl_handshake_listener_t
=
None
def
get_sw_clipped_blocks
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
"""
Clip the number of blocks to the sliding window size for each kv cache group
that employs SWA.
This is necessary because the KV Cache manager initially allocates blocks for
the entire sequence length, and subsequently cleans up blocks that are outside
the window prior to the `request_finished_all_groups` hook.
"""
if
len
(
block_ids
)
==
0
or
not
self
.
_is_hma_required
:
# No blocks to clip, e.g. a full prefix cache hit or not a hybrid model.
return
block_ids
# NOTE (NickLucche) This logic is currently handled at the connector level
# because offloading connectors might want to receive the whole sequence even
# for SWA groups. We will abstract this logic once the interface is more stable
assert
len
(
block_ids
)
==
len
(
self
.
blocks_per_sw
),
(
"Number of KV cache groups must match"
)
# For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
return
tuple
(
[
blocks
[
-
self
.
blocks_per_sw
[
i
]
:]
if
self
.
blocks_per_sw
[
i
]
>
0
else
blocks
for
i
,
blocks
in
enumerate
(
block_ids
)
]
)
def
set_xfer_handshake_metadata
(
self
,
metadata
:
dict
[
int
,
KVConnectorHandshakeMetadata
]
)
->
None
:
...
...
@@ -707,12 +772,18 @@ class NixlConnectorScheduler:
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids
=
(
blocks
.
get_unhashed_block_ids
()
unhashed_local_block_ids
:
BlockIds
=
(
blocks
.
get_unhashed_block_ids_all_groups
()
if
num_external_tokens
>
0
else
[]
else
()
)
# Get unhashed blocks to pull from remote.
local_block_ids
=
self
.
get_sw_clipped_blocks
(
unhashed_local_block_ids
)
# Get unhashed blocks to pull from remote. Mind that a full prefix
# cache hit is indicated with an empty list.
self
.
_reqs_need_recv
[
request
.
request_id
]
=
(
request
,
local_block_ids
,
...
...
@@ -753,9 +824,10 @@ class NixlConnectorScheduler:
req
=
req_to_save
assert
req
.
kv_transfer_params
is
not
None
clipped_block_id_groups
=
self
.
get_sw_clipped_blocks
(
new_block_id_groups
)
meta
.
add_new_req_to_save
(
request_id
=
req_id
,
local_block_ids
=
new
_block_id_groups
[
0
]
,
local_block_ids
=
clipped
_block_id_groups
,
kv_transfer_params
=
req
.
kv_transfer_params
,
)
assert
scheduler_output
.
num_scheduled_tokens
is
not
None
...
...
@@ -786,7 +858,7 @@ class NixlConnectorScheduler:
def
request_finished
(
self
,
request
:
"Request"
,
block_ids
:
list
[
int
]
,
block_ids
:
BlockIds
,
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
"""
Once a request is finished, determine whether request blocks
...
...
@@ -828,7 +900,7 @@ class NixlConnectorScheduler:
# TODO: check whether block_ids can actually ever be empty. If not we could
# remove the conditional below.
delay_free_blocks
=
len
(
block_ids
)
>
0
delay_free_blocks
=
any
(
len
(
group
)
>
0
for
group
in
block_ids
)
if
delay_free_blocks
:
# Prefill request on remote. It will be read from D upon completion
...
...
@@ -841,6 +913,11 @@ class NixlConnectorScheduler:
self
.
_reqs_need_send
[
request
.
request_id
]
=
(
time
.
perf_counter
()
+
envs
.
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
)
# NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
# trimming down after allocating for the whole sequence length. Empty
# blocks are always at the start of the list.
# Here we "unpad" blocks to send the actual remote blocks to be read.
block_ids
=
self
.
get_sw_clipped_blocks
(
block_ids
)
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
...
...
@@ -857,7 +934,9 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
,
kv_cache_config
:
"KVCacheConfig"
):
if
NixlWrapper
is
None
:
logger
.
error
(
"NIXL is not available"
)
raise
RuntimeError
(
"NIXL is not available"
)
...
...
@@ -875,6 +954,14 @@ class NixlConnectorWorker:
self
.
nixl_backends
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"backends"
,
[
"UCX"
]
)
self
.
_is_hma_required
=
(
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
and
any
(
not
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
self
.
kv_cache_config
=
kv_cache_config
# Agent.
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
...
...
@@ -1017,10 +1104,6 @@ class NixlConnectorWorker:
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self
.
block_window_per_layer
:
list
[
int
|
None
]
=
[]
self
.
use_mla
=
self
.
model_config
.
use_mla
# Get the attention backend from the first layer
...
...
@@ -1030,8 +1113,8 @@ class NixlConnectorWorker:
self
.
backend_name
=
self
.
attn_backend
.
get_name
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
# lazy initialized in register_kv_caches
self
.
compat_hash
:
str
|
None
=
None
...
...
@@ -1238,9 +1321,15 @@ class NixlConnectorWorker:
"remote_request_id"
:
meta
.
remote
.
request_id
,
"remote_host"
:
meta
.
remote
.
host
,
"remote_port"
:
meta
.
remote
.
port
,
"num_local_blocks"
:
len
(
meta
.
local_block_ids
),
"num_remote_blocks"
:
len
(
meta
.
remote
.
block_ids
),
"local_block_ids_sample"
:
meta
.
local_block_ids
[:
10
],
"num_local_blocks"
:
sum
(
len
(
group
)
for
group
in
meta
.
local_block_ids
),
"num_remote_blocks"
:
sum
(
len
(
group
)
for
group
in
meta
.
remote
.
block_ids
),
"local_block_ids_sample"
:
meta
.
local_block_ids
[
0
][:
10
]
if
meta
.
local_block_ids
else
[],
}
)
...
...
@@ -1301,8 +1390,10 @@ class NixlConnectorWorker:
error
=
e
,
meta
=
meta
,
)
if
req_meta
:
=
self
.
_recving_metadata
.
get
(
req_id
):
self
.
_invalid_block_ids
.
update
(
req_meta
.
local_block_ids
)
if
(
req_meta
:
=
self
.
_recving_metadata
.
get
(
req_id
)
)
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
req_meta
.
local_block_ids
[
0
])
self
.
_failed_recv_reqs
.
add
(
req_id
)
fut
.
add_done_callback
(
request_ready
)
...
...
@@ -1370,6 +1461,10 @@ class NixlConnectorWorker:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
if
base_addr
in
seen_base_addresses
:
# NOTE (NickLucche) HMA employs memory pooling to share tensors
# across groups. This results in skipping all tensors but the ones
# pointed to by group0. Also, generally we will have more blocks
# per tensor but fewer regions.
continue
logger
.
debug
(
...
...
@@ -1457,28 +1552,6 @@ class NixlConnectorWorker:
self
.
register_local_xfer_handler
(
self
.
block_size
)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
,
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
agent_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
...
...
@@ -1767,6 +1840,11 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
if
self
.
_is_hma_required
:
assert
block_size_ratio
==
1
,
(
"HMA does not support different remote block size yet"
)
kv_cache_layout
=
(
self
.
kv_cache_layout
if
not
self
.
use_host_buffer
...
...
@@ -1781,6 +1859,9 @@ class NixlConnectorWorker:
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
)
assert
not
self
.
_is_hma_required
,
(
"HMA does not support block size post processing"
)
self
.
enable_permute_local_kv
=
True
else
:
raise
RuntimeError
(
...
...
@@ -1836,11 +1917,13 @@ class NixlConnectorWorker:
assert
self
.
copy_blocks
is
not
None
local_block_ids
=
meta
.
local_physical_block_ids
# TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups
for
group_block_ids
in
local_block_ids
:
self
.
copy_blocks
(
self
.
host_xfer_buffers
,
self
.
device_kv_caches
,
local
_block_ids
,
local
_block_ids
,
group
_block_ids
,
group
_block_ids
,
"h2d"
,
)
if
logger
.
isEnabledFor
(
logging
.
DEBUG
):
...
...
@@ -1868,11 +1951,12 @@ class NixlConnectorWorker:
","
.
join
(
map
(
str
,
meta
.
local_physical_block_ids
)),
)
# blocking
for
group_block_ids
in
meta
.
local_physical_block_ids
:
self
.
copy_blocks
(
self
.
device_kv_caches
,
self
.
host_xfer_buffers
,
meta
.
local_physical
_block_ids
,
meta
.
local_physical
_block_ids
,
group
_block_ids
,
group
_block_ids
,
"d2h"
,
)
...
...
@@ -1973,8 +2057,9 @@ class NixlConnectorWorker:
if
not
self
.
use_mla
and
(
block_size_ratio
>
1
or
self
.
enable_permute_local_kv
):
assert
not
self
.
_is_hma_required
block_ids_for_blocksize_post_process
[
block_size_ratio
].
append
(
meta
.
local_physical_block_ids
meta
.
local_physical_block_ids
[
0
]
)
for
(
block_size_ratio
,
...
...
@@ -2106,8 +2191,9 @@ class NixlConnectorWorker:
handle: The transfer handle.
"""
# Use .get() here as the metadata cleanup is handled by get_finished()
if
meta
:
=
self
.
_recving_metadata
.
get
(
req_id
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
# TODO (NickLucche) handle failed transfer for HMA.
if
(
meta
:
=
self
.
_recving_metadata
.
get
(
req_id
))
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
[
0
])
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
xfer_stats
.
record_failed_transfer
()
...
...
@@ -2230,8 +2316,8 @@ class NixlConnectorWorker:
def
_read_blocks
(
self
,
local_block_ids
:
list
[
int
]
,
remote_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
remote_block_ids
:
BlockIds
,
dst_engine_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
...
...
@@ -2246,22 +2332,30 @@ class NixlConnectorWorker:
assert
self
.
kv_topo
is
not
None
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
local_block_ids
=
self
.
get_mapped_blocks
(
np
.
asarray
(
local_block_ids
),
block_size_ratio
)
if
len
(
local_block_ids
)
>
len
(
remote_block_ids
):
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
assert
not
self
.
_is_hma_required
local_block_ids0
=
local_block_ids
[
0
]
if
local_block_ids
else
[]
remote_block_ids0
=
remote_block_ids
[
0
]
local_block_ids_mapped
=
self
.
get_mapped_blocks
(
np
.
asarray
(
local_block_ids0
),
block_size_ratio
).
tolist
()
if
len
(
local_block_ids_mapped
)
>
len
(
remote_block_ids0
):
# NOTE:
# get_mapped_blocks will always expand block_ids for n times.
# ex:
# prefill block_ids with block_size as 4:
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Local decode block_ids with block_size as 16: [1, 2, 3]
# exp
l
and
ecode block_ids with get_mapped_blocks from [1, 2, 3] to
# expand
ed d
ecode block_ids with get_mapped_blocks from [1, 2, 3] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# Then we clip local to align with prefill
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
local_block_ids
=
local_block_ids
[:
len
(
remote_block_ids
)]
local_block_ids_mapped
=
local_block_ids_mapped
[
:
len
(
remote_block_ids0
)
]
local_block_ids
=
[
local_block_ids_mapped
]
if
local_block_ids_mapped
else
[]
remote_block_ids
=
[
remote_block_ids0
]
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
...
...
@@ -2269,8 +2363,7 @@ class NixlConnectorWorker:
# then we will need to have the staging blocks on the remote side.
# NOTE(rob): according to nvidia the staging blocks are used to
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# saturate IB with heterogeneous TP sizes.
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
...
...
@@ -2278,8 +2371,8 @@ class NixlConnectorWorker:
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
if
len
(
local_block_ids
)
==
0
:
# A full prefix cache hit is indicated with an empty list.
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
...
@@ -2297,22 +2390,25 @@ class NixlConnectorWorker:
self
.
xfer_stats
.
record_failed_notification
()
return
# Partial prefix cache hit: just read uncomputed blocks.
num_remote_blocks
=
len
(
remote_block_ids
)
assert
(
len
(
remote_block_ids
)
==
len
(
local_block_ids
)
==
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
)
remote_block_ids
=
list
(
remote_block_ids
)
for
i
,
remote_group
in
enumerate
(
remote_block_ids
):
num_remote_blocks
=
len
(
remote_group
)
num_local_blocks
=
len
(
local_block_ids
[
i
])
assert
num_local_blocks
<=
num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
if
num_local_blocks
<
num_remote_blocks
:
remote_block_ids
=
remote_
block_ids
[
-
num_local_blocks
:]
remote_block_ids
[
i
]
=
remote_
group
[
-
num_local_blocks
:]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids
:
np
.
ndarray
remote_block_descs_ids
:
np
.
ndarray
if
not
self
.
block_window_per_layer
:
# Default case: assume global attention
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
remote_block_ids
,
...
...
@@ -2322,41 +2418,6 @@ class NixlConnectorWorker:
local_block_ids
,
block_size_ratio
=
block_size_ratio
,
)
else
:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list
=
[]
remote_descs_list
=
[]
for
layer_idx
,
block_window
in
enumerate
(
self
.
block_window_per_layer
):
# For each layer:
if
block_window
is
None
:
# If not chunked, we just use the
# full block lists (global attention)
layer_local_block_ids
=
local_block_ids
layer_remote_block_ids
=
remote_block_ids
else
:
# If chunked, get the last block_window blocks
layer_local_block_ids
=
local_block_ids
[
-
block_window
:]
layer_remote_block_ids
=
remote_block_ids
[
-
block_window
:]
# Get descs ids for the layer.
layer_local_desc_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
layer_local_block_ids
,
layer_idx
,
block_size_ratio
=
block_size_ratio
,
)
layer_remote_desc_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
layer_remote_block_ids
,
layer_idx
,
)
local_descs_list
.
append
(
layer_local_desc_ids
)
remote_descs_list
.
append
(
layer_remote_desc_ids
)
local_block_descs_ids
=
np
.
concatenate
(
local_descs_list
)
remote_block_descs_ids
=
np
.
concatenate
(
remote_descs_list
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
...
...
@@ -2387,14 +2448,18 @@ class NixlConnectorWorker:
dst_engine_id
=
dst_engine_id
,
remote_rank
=
remote_rank
,
)
if
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
if
(
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
)
)
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
[
0
])
self
.
xfer_stats
.
record_failed_transfer
()
if
handle
is
not
None
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_failed_recv_reqs
.
add
(
request_id
)
def
get_mapped_blocks
(
self
,
block_ids
,
block_size_ratio
):
def
get_mapped_blocks
(
self
,
block_ids
:
np
.
ndarray
,
block_size_ratio
:
int
)
->
np
.
ndarray
:
"""
Calculates the new set of block IDs by mapping every element
in the (potentially sparse) input array.
...
...
@@ -2416,41 +2481,32 @@ class NixlConnectorWorker:
def
_get_block_descs_ids
(
self
,
engine_id
:
str
,
block_ids
:
list
[
int
],
layer_idx
:
int
|
None
=
None
,
block_ids
:
BlockIds
,
block_size_ratio
:
float
|
None
=
None
,
)
->
np
.
ndarray
:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given lay
er.
Otherwise, we use all regions
.
When HMA is enabled number of descriptors across kv cache groups might diff
er.
A single flattened array is returned for all groups anyway
.
"""
if
layer_idx
is
None
:
region_ids
=
np
.
arange
(
self
.
num_regions
)
else
:
assert
layer_idx
<
self
.
num_layers
if
self
.
num_layers
<
self
.
num_regions
:
# If we have more regions than layers, we assume that
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert
2
*
self
.
num_layers
==
self
.
num_regions
region_ids
=
np
.
arange
(
2
*
layer_idx
,
2
*
layer_idx
+
2
)
else
:
# Otherwise, we assume we have MLA and select i-th layer
assert
self
.
num_layers
==
self
.
num_regions
region_ids
=
np
.
arange
(
layer_idx
,
layer_idx
+
1
)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# layers from different groups share the same kv tensor.
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
# same for [3], but group0-group1 blocks will always differ (different areas).
# Therefore we can just flatten the block_ids and compute the descs ids for all
# groups at once.
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
if
block_size_ratio
is
not
None
:
num_blocks
=
int
(
num_blocks
*
block_size_ratio
)
# Compute the desc ids for each block.
region_ids
=
region_ids
[:,
None
]
block_ids
=
np
.
array
(
block_ids
)[
None
,
:]
block_ids
=
np
.
concatenate
(
block_ids
)[
None
,
:]
descs_ids
=
region_ids
*
num_blocks
+
block_ids
return
descs_ids
.
flatten
()
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
list
[
int
])
->
list
[
int
]
:
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
...
...
@@ -2459,13 +2515,17 @@ class NixlConnectorWorker:
if
self
.
_physical_blocks_per_logical_kv_block
==
1
:
# Noop when physical and logical block sizes are the same
return
block_ids
block_ids_np
=
np
.
array
(
block_ids
)
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
1
,
-
1
)
return
BlockTable
.
map_to_kernel_blocks
(
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
return
[
BlockTable
.
map_to_kernel_blocks
(
np
.
array
(
group
),
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
,
).
tolist
()
for
group
in
block_ids
]
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
"""
...
...
vllm/v1/core/kv_cache_manager.py
View file @
5b3ba94a
...
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
def
get_unhashed_block_ids_all_groups
(
self
)
->
list
[
list
[
int
]]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
# Skip padding blocks.
return
[
[
block
.
block_id
for
block
in
group
if
block
.
block_hash
is
None
and
not
block
.
is_null
]
for
group
in
self
.
blocks
]
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""
Creates a new KVCacheBlocks instance with no blocks.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment