Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b3ba94a
Unverified
Commit
5b3ba94a
authored
Mar 06, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 06, 2026
Browse files
[Core][KVConnector] Support HMA+NixlConnector (#35758)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
90f3c01f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
669 additions
and
230 deletions
+669
-230
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+9
-0
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+36
-1
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+122
-46
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+203
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+3
-1
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+53
-15
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+3
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+227
-167
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
5b3ba94a
...
@@ -12,6 +12,7 @@ tp_configs=(
...
@@ -12,6 +12,7 @@ tp_configs=(
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
# SW model
)
)
dp_ep_configs
=(
dp_ep_configs
=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
...
@@ -26,6 +27,14 @@ else
...
@@ -26,6 +27,14 @@ else
configs
=(
"
${
tp_configs
[@]
}
"
)
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
fi
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
echo
"ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
for
i
in
"
${
!configs[@]
}
"
;
do
configs[
$i
]=
"ENABLE_HMA_FLAG=1
${
configs
[
$i
]
}
"
done
fi
run_tests
()
{
run_tests
()
{
local
label
=
$1
local
label
=
$1
local
extra_args
=
$2
local
extra_args
=
$2
...
...
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
5b3ba94a
...
@@ -5,6 +5,12 @@ set -xe
...
@@ -5,6 +5,12 @@ set -xe
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
CROSS_LAYERS_BLOCKS
=
"False"
CROSS_LAYERS_BLOCKS
=
"False"
ENABLE_HMA_VAR
=
""
# Default to empty (HMA disabled by default for kv connector)
# Check for ENABLE_HMA_FLAG environment variable
if
[[
-n
"
${
ENABLE_HMA_FLAG
:-}
"
]]
;
then
ENABLE_HMA_VAR
=
"--no-disable-hybrid-kv-cache-manager"
fi
while
[[
$#
-gt
0
]]
;
do
while
[[
$#
-gt
0
]]
;
do
case
$1
in
case
$1
in
--kv_buffer_device
)
--kv_buffer_device
)
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
...
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
echo
"Using attention backend:
$ATTENTION_BACKEND
"
echo
"Using attention backend:
$ATTENTION_BACKEND
"
fi
fi
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
echo
"HMA (Hybrid KV Cache Manager) enabled"
fi
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
echo
"vLLM serve extra args:
$VLLM_SERVE_EXTRA_ARGS
"
fi
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
...
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
GPU_MEMORY_UTILIZATION
=
${
GPU_MEMORY_UTILIZATION
:-
0
.2
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
PREFILL_BLOCK_SIZE
=
${
PREFILL_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
DECODE_BLOCK_SIZE
=
${
DECODE_BLOCK_SIZE
:-
128
}
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
VLLM_SERVE_EXTRA_ARGS
=
${
VLLM_SERVE_EXTRA_ARGS
:-}
# Find the git repository root directory
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
...
@@ -151,14 +165,24 @@ run_tests_for_model() {
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
FULL_CMD
=
"
$BASE_CMD
"
FULL_CMD
=
"
$BASE_CMD
"
eval
"
$FULL_CMD
&"
eval
"
$FULL_CMD
&"
# Store host and port for proxy configuration
# Store host and port for proxy configuration
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
...
@@ -193,12 +217,23 @@ run_tests_for_model() {
--block-size
${
DECODE_BLOCK_SIZE
}
\
--block-size
${
DECODE_BLOCK_SIZE
}
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
if
[[
-n
"
$VLLM_SERVE_EXTRA_ARGS
"
]]
;
then
IFS
=
','
read
-r
-a
extra_args
<<<
"
$VLLM_SERVE_EXTRA_ARGS
"
for
arg
in
"
${
extra_args
[@]
}
"
;
do
BASE_CMD
=
"
${
BASE_CMD
}
$arg
"
done
fi
# Add attention backend config if specified
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
fi
# Add HMA flag if specified
if
[[
-n
"
$ENABLE_HMA_VAR
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
$ENABLE_HMA_VAR
"
fi
# DP-EP attention mode
# DP-EP attention mode
if
[[
-z
"
$DP_EP
"
]]
;
then
if
[[
-z
"
$DP_EP
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
5b3ba94a
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
...
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-small"
:
0.59
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
}
}
SIMPLE_PROMPT
=
(
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
5b3ba94a
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
...
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
(
create_request
,
create_scheduler
,
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
...
@@ -263,7 +268,7 @@ def test_basic_interface():
...
@@ -263,7 +268,7 @@ def test_basic_interface():
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
req_meta
=
kv_connector_metadata
.
reqs_to_recv
[
request_id
]
for
block_id
,
block
in
zip
(
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
req_meta
.
local_block_ids
[
0
]
,
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
request_id
request_id
],
],
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
)
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
do_remote_decode
=
True
,
do_remote_decode
=
True
,
)
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
delay
,
kv_connector_metadata
=
(
request
,
[
0
,
1
,
2
]
scheduler
.
get_kv_connector
().
request_finished_all_groups
(
request
,
([
0
,
1
,
2
],)
)
)
)
assert
delay
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_connector
.
register_kv_caches
(
kv_caches
)
decode_connector
.
register_kv_caches
(
kv_caches
)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Here we are testing the retrieval of NIXLAgentMetadata.
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID
=
"remote_engine"
REMOTE_ENGINE_ID
=
"remote_engine"
def
__init__
(
def
__init__
(
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
**
kwargs
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
kv_cache_layout
=
"HND"
,
kv_cache_config
=
None
,
**
kwargs
,
):
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
kv_cache_config
is
None
:
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
)
super
().
__init__
(
*
args
,
kv_cache_config
=
kv_cache_config
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
# Mock register_kv_caches attribute needed for tests that do not call it.
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
...
@@ -507,7 +525,9 @@ class TestNixlHandshake:
request_id
=
"req_id"
request_id
=
"req_id"
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
...
@@ -528,13 +548,15 @@ class TestNixlHandshake:
num_xfers
-=
1
num_xfers
-=
1
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
local_block_ids
=
(
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
"remote_block_ids"
:
(
num_xfers
+
4
,
[
num_xfers
+
5
,
num_xfers
+
4
,
num_xfers
+
6
,
num_xfers
+
5
,
],
num_xfers
+
6
,
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
...
@@ -594,16 +616,18 @@ class TestNixlHandshake:
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
vllm_config
,
connector
.
engine_id
)
)
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
"id"
,
request_id
=
"id"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
"prefill-id"
,
"remote_request_id"
:
"prefill-id"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
...
@@ -652,7 +676,9 @@ class TestNixlHandshake:
local_tp_size
=
1
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
...
@@ -717,8 +743,12 @@ class TestNixlHandshake:
p_tp_size
=
2
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
=
NixlConnector
(
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
...
@@ -815,7 +845,9 @@ class TestNixlHandshake:
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
vllm_config
,
connector
.
engine_id
)
)
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
...
@@ -827,9 +859,9 @@ class TestNixlHandshake:
for
i
in
range
(
total_reqs
):
for
i
in
range
(
total_reqs
):
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
f
"id_
{
i
}
"
,
request_id
=
f
"id_
{
i
}
"
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
...
@@ -884,7 +916,9 @@ class TestNixlHandshake:
return_value
=
2
,
return_value
=
2
,
):
):
# Initialize connector and worker (with fake NIXL wrapper)
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
...
@@ -934,7 +968,9 @@ class TestNixlHandshake:
return_value
=
2
,
return_value
=
2
,
):
):
# Initialize connector and worker (with fake NIXL wrapper)
# Initialize connector and worker (with fake NIXL wrapper)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
vllm_config
,
connector
.
engine_id
,
connector
.
engine_id
,
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
...
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
...
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
...
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backend
.
return_value
=
backend_cls
# Create connector
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
...
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
),
),
):
# noqa: E501
):
# noqa: E501
# Create connector and replace its worker with a fake one for isolation
# Create connector and replace its worker with a fake one for isolation
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
# Verify get_reg_descs was called with the correct memory_type
# Verify get_reg_descs was called with the correct memory_type
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
assert
connector
.
connector_worker
.
kv_buffer_device
==
kv_buffer_device
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
...
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
scheduler
=
NixlConnectorScheduler
(
scheduler
=
NixlConnectorScheduler
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
,
make_kv_cache_config
(
block_size
=
16
),
)
)
worker
=
NixlConnectorWorker
(
vllm_config
,
vllm_config
.
kv_transfer_config
.
engine_id
)
nixl_wrapper
=
worker
.
nixl_wrapper
nixl_wrapper
=
worker
.
nixl_wrapper
with
(
with
(
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
...
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
scheduler
=
create_scheduler
(
vllm_config
)
scheduler
=
create_scheduler
(
vllm_config
)
# KVConnector Worker in P
# KVConnector Worker in P
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
...
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
(
"transfer_exception"
,
{
"fail_transfer_exception"
:
True
},
True
),
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"enable_hma"
,
[
False
,
True
])
def
test_transfer_failure_logging
(
def
test_transfer_failure_logging
(
default_vllm_config
,
default_vllm_config
,
dist_init
,
dist_init
,
failure_type
,
failure_type
,
wrapper_config
,
wrapper_config
,
needs_get_finished
,
needs_get_finished
,
enable_hma
,
):
):
"""Test that transfer failures are logged with structured context.
"""Test that transfer failures are logged with structured context.
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
...
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
enable_hma
),
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.0
,
kv_cache_config
=
connector
.
_kv_cache_config
,
)
)
# Configure FailingNixlWrapper to fail in the specified way
# Configure FailingNixlWrapper to fail in the specified way
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
...
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
# For notification_failed, we need empty local blocks
# For notification_failed, we need empty local blocks
# (full cache hit path to trigger send_notif)
# (full cache hit path to trigger send_notif)
local_blocks
=
[]
if
failure_type
==
"notification_failed"
else
[
10
,
11
,
12
]
local_blocks
:
tuple
[()]
|
tuple
[
list
[
int
],
...]
remote_blocks
=
[
20
,
21
,
22
]
if
enable_hma
:
# HMA enabled: multiple groups (FA + SW)
local_blocks
=
(
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],
[
13
,
14
])
)
remote_blocks
=
[[
20
,
21
,
22
],
[
23
,
24
]]
else
:
# HMA disabled: single group
local_blocks
=
()
if
failure_type
==
"notification_failed"
else
([
10
,
11
,
12
],)
remote_blocks
=
[[
20
,
21
,
22
]]
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
)
)
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
local_block_ids
=
(
[
1
,
2
,
3
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_block_ids"
:
(
[
4
,
5
,
6
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
...
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
and return via get_finished."""
and return via get_finished."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
...
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req_to_recv
(
metadata
.
add_new_req_to_recv
(
request_id
=
request_id
,
request_id
=
request_id
,
local_block_ids
=
[
7
,
8
,
9
],
local_block_ids
=
(
[
7
,
8
,
9
],
),
kv_transfer_params
=
{
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_block_ids"
:
(
[
10
,
11
,
12
],
),
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
...
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat"
:
enforce_handshake_compat
"enforce_handshake_compat"
:
enforce_handshake_compat
},
},
)
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
decode_worker
=
decode_connector
.
connector_worker
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
...
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
model
=
"facebook/opt-125m"
,
model
=
"facebook/opt-125m"
,
block_size
=
16
,
block_size
=
16
,
)
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
decode_worker
=
decode_connector
.
connector_worker
decode_worker
=
decode_connector
.
connector_worker
backend
=
get_current_attn_backend
(
local_vllm_config
)
backend
=
get_current_attn_backend
(
local_vllm_config
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
0 → 100644
View file @
5b3ba94a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
from
unittest.mock
import
patch
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SlidingWindowManager
,
)
from
.utils
import
(
create_vllm_config
,
make_kv_cache_config
,
)
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"hma_enabled,expected_sw_sizes"
,
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(
True
,
[
0
,
128
+
1
]),
# HMA disabled: only FullAttentionSpec (0)
(
False
,
[
0
]),
],
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform"
)
def
test_sw_sizes
(
mock_platform
,
hma_enabled
,
expected_sw_sizes
):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorScheduler
,
)
mock_platform
.
device_type
=
"cpu"
block_size
=
16
vllm_config
=
create_vllm_config
(
block_size
=
block_size
)
# SW 2048 tokens=>128 blocks
kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
hma_enabled
=
hma_enabled
,
sw_size
=
2048
)
scheduler
=
NixlConnectorScheduler
(
vllm_config
=
vllm_config
,
engine_id
=
"test-engine"
,
kv_cache_config
=
kv_cache_config
,
)
# in number of blocks
assert
scheduler
.
blocks_per_sw
==
expected_sw_sizes
,
(
f
"Expected sw_sizes=
{
expected_sw_sizes
}
, got
{
scheduler
.
blocks_per_sw
}
"
)
@
pytest
.
mark
.
cpu_test
def
test_logical_to_kernel_block_ids_with_hma
():
"""Test _logical_to_kernel_block_ids expands blocks when HMA is enabled.
When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
# Create a mock worker with just the required attributes
# (use __new__ to skip __init__)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
kernel_block_ids
=
worker
.
_logical_to_kernel_block_ids
(
logical_block_ids
)
expected_kernel_block_ids
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
6
,
7
,
8
,
9
]]
assert
kernel_block_ids
==
expected_kernel_block_ids
,
(
f
"Expected
{
expected_kernel_block_ids
}
, got
{
kernel_block_ids
}
"
)
@
pytest
.
mark
.
parametrize
(
"model_name, sw_size"
,
[(
"google/gemma-3-1b-it"
,
512
)])
def
test_fewer_blocks_with_hma
(
monkeypatch
,
model_name
,
sw_size
):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
when sequence exceeds the sliding window.
"""
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
)
block_size
=
16
llm_kwargs
=
{
"model"
:
model_name
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.5
,
"kv_transfer_config"
:
kv_transfer_config
,
"max_model_len"
:
2048
,
# NOTE: Make sure HMA is enabled
"disable_hybrid_kv_cache_manager"
:
False
,
"max_num_batched_tokens"
:
1024
,
"enable_prefix_caching"
:
False
,
"block_size"
:
block_size
,
}
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
def
run_hma_test
(
llm
:
LLM
):
remote_prefill_opts
=
{
"do_remote_decode"
:
True
,
"do_remote_prefill"
:
False
,
"remote_engine_id"
:
None
,
"remote_block_ids"
:
None
,
"remote_host"
:
None
,
"remote_port"
:
None
,
}
# Simulate sidecar request
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
1
,
extra_args
=
{
"kv_transfer_params"
:
remote_prefill_opts
},
)
scheduler
=
llm
.
llm_engine
.
engine_core
.
engine_core
.
scheduler
kv_managers
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
# HMA enabled with FA + SWA groups
assert
len
(
kv_managers
)
>
2
for
kv_manager
in
kv_managers
:
assert
isinstance
(
kv_manager
,
(
SlidingWindowManager
,
FullAttentionManager
))
req_to_blocks
=
kv_managers
[
0
].
req_to_blocks
assert
len
(
req_to_blocks
)
==
0
# Process some request with length exceeding the sliding window
outputs
=
llm
.
generate
([
"hi"
*
1401
],
sampling_params
)
kv_params
=
outputs
[
0
].
kv_transfer_params
# +1 to account for overlapping window across blocks.
expected_num_remote_blocks
=
sw_size
//
block_size
+
1
remote_block_ids
=
kv_params
[
"remote_block_ids"
]
assert
(
len
(
remote_block_ids
[
0
])
==
expected_num_remote_blocks
<
len
(
remote_block_ids
[
-
1
])
)
for
group_block_ids
in
remote_block_ids
[:
-
1
]:
assert
len
(
group_block_ids
)
==
expected_num_remote_blocks
def
run_test_and_cleanup
():
llm
=
LLM
(
**
llm_kwargs
)
try
:
run_hma_test
(
llm
)
finally
:
llm
.
llm_engine
.
engine_core
.
shutdown
()
run_test_and_cleanup
()
@
pytest
.
mark
.
cpu_test
def
test_nixl_metadata_hma_block_ids_structure
():
"""
Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
,
)
metadata
=
NixlConnectorMetadata
()
# Add request with block IDs for 2 groups (FA + SW)
fa_blocks
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
# 8 blocks for FA
sw_blocks
=
[
8
,
9
,
10
,
11
]
# 4 blocks for SW (clipped)
metadata
.
add_new_req_to_recv
(
request_id
=
"test-req-hma"
,
local_block_ids
=
(
fa_blocks
,
sw_blocks
),
kv_transfer_params
=
{
"remote_block_ids"
:
([
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
],
[
18
,
19
,
20
,
21
]),
"remote_engine_id"
:
"remote-engine"
,
"remote_request_id"
:
"prefill-test-req-hma"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"tp_size"
:
1
,
},
)
assert
"test-req-hma"
in
metadata
.
reqs_to_recv
req_meta
=
metadata
.
reqs_to_recv
[
"test-req-hma"
]
# Verify local block IDs structure
assert
len
(
req_meta
.
local_block_ids
)
==
2
assert
list
(
req_meta
.
local_block_ids
[
0
])
==
fa_blocks
assert
list
(
req_meta
.
local_block_ids
[
1
])
==
sw_blocks
# Verify remote block IDs structure
assert
req_meta
.
remote
is
not
None
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
5b3ba94a
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
...
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
# even if there is a cache hit.
assert
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
num_remote_blocks
=
sum
(
len
(
g
)
for
g
in
kv_transfer_params
[
"remote_block_ids"
])
assert
num_remote_blocks
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
)
# STEP (2): Ensure it is freed.
# STEP (2): Ensure it is freed.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
...
...
tests/v1/kv_connector/unit/utils.py
View file @
5b3ba94a
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
...
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheGroupSpec
,
SlidingWindowSpec
,
)
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -142,24 +143,26 @@ def create_vllm_config(
...
@@ -142,24 +143,26 @@ def create_vllm_config(
def
create_scheduler
(
def
create_scheduler
(
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
num_blocks
:
int
=
10000
,
num_blocks
:
int
=
10000
,
kv_cache_config
:
KVCacheConfig
|
None
=
None
,
)
->
Scheduler
:
)
->
Scheduler
:
"""Initialize Scheduler For Testing."""
"""Initialize Scheduler For Testing."""
block_size
=
vllm_config
.
cache_config
.
block_size
block_size
=
vllm_config
.
cache_config
.
block_size
kv_cache_config
=
KVCacheConfig
(
if
kv_cache_config
is
None
:
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_config
=
KVCacheConfig
(
kv_cache_tensors
=
[],
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
kv_cache_groups
=
[
kv_cache_tensors
=
[],
KVCacheGroupSpec
(
kv_cache_groups
=
[
[
"layer"
],
KVCacheGroupSpec
(
FullAttentionSpec
(
[
"layer"
],
block_size
=
block_size
,
FullAttentionSpec
(
num_kv_heads
=
1
,
block_size
=
block_size
,
head_size
=
1
,
num_kv_heads
=
1
,
dtype
=
torch
.
float32
,
head_size
=
1
,
),
dtype
=
torch
.
float32
,
)
),
],
)
)
],
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_blocks
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_blocks
return
Scheduler
(
return
Scheduler
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
...
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
KVConnectorFactory
.
register_connector
(
KVConnectorFactory
.
register_connector
(
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
"MockKVConnector"
,
__name__
,
MockKVConnector
.
__name__
)
)
def
make_kv_cache_config
(
block_size
:
int
,
hma_enabled
:
bool
=
False
,
sw_size
:
int
=
128
,
num_blocks
:
int
=
100
,
)
->
KVCacheConfig
:
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer0"
,
"layer2"
],
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
),
)
]
if
hma_enabled
:
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
[
"layer1"
,
"layer3"
],
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
sliding_window
=
sw_size
,
),
)
)
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
5b3ba94a
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
...
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
EngineId
=
str
EngineId
=
str
# block ids as returned by the hybrid KV cache manager. list[list[int]] are allow
# mutability and are for connector internal use only.
BlockIds
=
tuple
[
list
[
int
],
...]
|
list
[
list
[
int
]]
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
5b3ba94a
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
contextlib
import
contextlib
import
copy
import
copy
import
logging
import
logging
import
math
import
os
import
os
import
queue
import
queue
import
sys
import
sys
...
@@ -24,6 +23,7 @@ import zmq
...
@@ -24,6 +23,7 @@ import zmq
from
vllm
import
envs
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
BlockIds
,
EngineId
,
EngineId
,
TpKVTopology
,
TpKVTopology
,
get_current_attn_backend
,
get_current_attn_backend
,
...
@@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...
@@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata
,
KVConnectorHandshakeMetadata
,
KVConnectorMetadata
,
KVConnectorMetadata
,
KVConnectorRole
,
KVConnectorRole
,
SupportsHMA
,
)
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorPromMetrics
,
...
@@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import (
...
@@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash(
...
@@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash(
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
is_hma_enabled
=
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
factors
=
{
factors
=
{
# Version compatibility
# Version compatibility
...
@@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash(
...
@@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash(
"attn_backend_name"
:
attn_backend_name
,
"attn_backend_name"
:
attn_backend_name
,
"cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
"cache_dtype"
:
str
(
cache_config
.
cache_dtype
),
"cross_layers_blocks"
:
cross_layers_blocks
,
"cross_layers_blocks"
:
cross_layers_blocks
,
"is_hma_enabled"
:
is_hma_enabled
,
}
}
compat_hash
=
hash_factors
(
factors
)
compat_hash
=
hash_factors
(
factors
)
...
@@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash(
...
@@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash(
@
dataclass
@
dataclass
class
RemoteMeta
:
class
RemoteMeta
:
block_ids
:
list
[
int
]
block_ids
:
BlockIds
host
:
str
host
:
str
port
:
int
port
:
int
engine_id
:
str
engine_id
:
str
...
@@ -247,9 +252,9 @@ class RemoteMeta:
...
@@ -247,9 +252,9 @@ class RemoteMeta:
@
dataclass
@
dataclass
class
ReqMeta
:
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
BlockIds
# To be used when logical block size does not match the kernel block size
# To be used when logical block size does not match the kernel block size
local_physical_block_ids
:
list
[
int
]
local_physical_block_ids
:
BlockIds
tp_size
:
int
tp_size
:
int
remote
:
RemoteMeta
|
None
=
None
remote
:
RemoteMeta
|
None
=
None
...
@@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
_add_new_req
(
def
_add_new_req
(
self
,
self
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
kv_transfer_params
:
dict
[
str
,
Any
],
)
->
ReqMeta
:
)
->
ReqMeta
:
return
ReqMeta
(
return
ReqMeta
(
...
@@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
add_new_req_to_save
(
def
add_new_req_to_save
(
self
,
self
,
request_id
:
ReqId
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
):
self
.
reqs_to_save
[
request_id
]
=
self
.
_add_new_req
(
self
.
reqs_to_save
[
request_id
]
=
self
.
_add_new_req
(
...
@@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def
add_new_req_to_recv
(
def
add_new_req_to_recv
(
self
,
self
,
request_id
:
ReqId
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
kv_transfer_params
:
dict
[
str
,
Any
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
):
req
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
req
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
...
@@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
.
reqs_to_recv
[
request_id
]
=
req
self
.
reqs_to_recv
[
request_id
]
=
req
class
NixlConnector
(
KVConnectorBase_V1
):
class
NixlConnector
(
KVConnectorBase_V1
,
SupportsHMA
):
@
property
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
backend
=
get_current_attn_backend
(
self
.
_vllm_config
)
backend
=
get_current_attn_backend
(
self
.
_vllm_config
)
...
@@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1):
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
,
role
:
KVConnectorRole
,
kv_cache_config
:
"KVCacheConfig
| None"
=
None
,
kv_cache_config
:
"KVCacheConfig
"
,
):
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
for
group
in
kv_cache_config
.
kv_cache_groups
:
if
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
):
raise
ValueError
(
"NixlConnector does not support Mamba models."
)
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
if
role
==
KVConnectorRole
.
SCHEDULER
:
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
connector_scheduler
:
NixlConnectorScheduler
|
None
=
(
self
.
connector_scheduler
:
NixlConnectorScheduler
|
None
=
(
NixlConnectorScheduler
(
vllm_config
,
self
.
engine_id
)
NixlConnectorScheduler
(
vllm_config
,
self
.
engine_id
,
kv_cache_config
)
)
)
self
.
connector_worker
:
NixlConnectorWorker
|
None
=
None
self
.
connector_worker
:
NixlConnectorWorker
|
None
=
None
elif
role
==
KVConnectorRole
.
WORKER
:
elif
role
==
KVConnectorRole
.
WORKER
:
self
.
connector_scheduler
=
None
self
.
connector_scheduler
=
None
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
self
.
engine_id
)
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
self
.
engine_id
,
kv_cache_config
)
############################################################
############################################################
# Class Methods
# Class Methods
...
@@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1):
assert
self
.
connector_scheduler
is
not
None
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
build_connector_meta
(
scheduler_output
)
return
self
.
connector_scheduler
.
build_connector_meta
(
scheduler_output
)
def
request_finished
(
def
request_finished
_all_groups
(
self
,
self
,
request
:
"Request"
,
request
:
"Request"
,
block_ids
:
list
[
int
],
block_ids
:
tuple
[
list
[
int
],
...],
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
assert
self
.
connector_scheduler
is
not
None
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
request_finished
(
request
,
block_ids
)
return
self
.
connector_scheduler
.
request_finished
(
request
,
block_ids
)
...
@@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1):
class
NixlConnectorScheduler
:
class
NixlConnectorScheduler
:
"""Implementation of Scheduler side methods"""
"""Implementation of Scheduler side methods"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
,
kv_cache_config
:
"KVCacheConfig"
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
engine_id
:
EngineId
=
engine_id
self
.
engine_id
:
EngineId
=
engine_id
self
.
kv_cache_config
=
kv_cache_config
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_port
=
(
self
.
side_channel_port
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
...
@@ -534,8 +547,18 @@ class NixlConnectorScheduler:
...
@@ -534,8 +547,18 @@ class NixlConnectorScheduler:
self
.
use_host_buffer
=
(
self
.
use_host_buffer
=
(
vllm_config
.
kv_transfer_config
.
kv_buffer_device
==
"cpu"
vllm_config
.
kv_transfer_config
.
kv_buffer_device
==
"cpu"
)
)
self
.
_is_hma_required
=
(
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
# Also handle unlikely SW-only model case instead of checking num_groups>1.
and
any
(
not
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
if
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
logger
.
info
(
"Hybrid Memory Allocator is enabled with NIXL"
)
# Background thread for handling new handshake requests.
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
...
@@ -545,7 +568,7 @@ class NixlConnectorScheduler:
...
@@ -545,7 +568,7 @@ class NixlConnectorScheduler:
# Requests that need to start recv/send.
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]
]]
=
{}
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
BlockIds
]]
=
{}
self
.
_reqs_need_save
:
dict
[
ReqId
,
Request
]
=
{}
self
.
_reqs_need_save
:
dict
[
ReqId
,
Request
]
=
{}
# Reqs to send and their expiration time
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
...
@@ -554,12 +577,54 @@ class NixlConnectorScheduler:
...
@@ -554,12 +577,54 @@ class NixlConnectorScheduler:
# remote prefill or aborted.
# remote prefill or aborted.
self
.
_reqs_not_processed
:
set
[
ReqId
]
=
set
()
self
.
_reqs_not_processed
:
set
[
ReqId
]
=
set
()
# Gather Sliding Window sizes for each kv cache group (if any) in number of
# blocks per KV cache group. This is used to clip the local attention window.
sw_sizes_tokens
:
list
[
tuple
[
int
,
int
]]
=
[
(
g
.
kv_cache_spec
.
sliding_window
,
g
.
kv_cache_spec
.
block_size
)
if
isinstance
(
g
.
kv_cache_spec
,
SlidingWindowSpec
)
else
(
0
,
self
.
block_size
)
for
g
in
kv_cache_config
.
kv_cache_groups
]
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
# account for boundary overlap eg window isn't fully aligned with blocks.
self
.
blocks_per_sw
=
[
cdiv
(
n_tokens
,
block_size
)
+
1
if
n_tokens
else
0
for
n_tokens
,
block_size
in
sw_sizes_tokens
]
def
shutdown
(
self
):
def
shutdown
(
self
):
self
.
_stop_event
.
set
()
self
.
_stop_event
.
set
()
if
self
.
_nixl_handshake_listener_t
is
not
None
:
if
self
.
_nixl_handshake_listener_t
is
not
None
:
self
.
_nixl_handshake_listener_t
.
join
()
self
.
_nixl_handshake_listener_t
.
join
()
self
.
_nixl_handshake_listener_t
=
None
self
.
_nixl_handshake_listener_t
=
None
def
get_sw_clipped_blocks
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
"""
Clip the number of blocks to the sliding window size for each kv cache group
that employs SWA.
This is necessary because the KV Cache manager initially allocates blocks for
the entire sequence length, and successively cleans up blocks that are outside
the window prior to the `request_finished_all_groups` hook.
"""
if
len
(
block_ids
)
==
0
or
not
self
.
_is_hma_required
:
# No blocks to clip eg Full prefix cache hit or not a hybrid model.
return
block_ids
# NOTE (NickLucche) This logic is currently handled at the connector level
# because offloading connectors might want to receive the whole sequence even
# for SWA groups. We will abstract this logic once the interface is more stable
assert
len
(
block_ids
)
==
len
(
self
.
blocks_per_sw
),
(
"Number of KV cache groups must match"
)
# For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
return
tuple
(
[
blocks
[
-
self
.
blocks_per_sw
[
i
]
:]
if
self
.
blocks_per_sw
[
i
]
>
0
else
blocks
for
i
,
blocks
in
enumerate
(
block_ids
)
]
)
def
set_xfer_handshake_metadata
(
def
set_xfer_handshake_metadata
(
self
,
metadata
:
dict
[
int
,
KVConnectorHandshakeMetadata
]
self
,
metadata
:
dict
[
int
,
KVConnectorHandshakeMetadata
]
)
->
None
:
)
->
None
:
...
@@ -707,12 +772,18 @@ class NixlConnectorScheduler:
...
@@ -707,12 +772,18 @@ class NixlConnectorScheduler:
# If remote_blocks and num_external_tokens = 0, we have
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
# send_notif in _read_blocks to free the memory on the P.
local_block_ids
=
(
blocks
.
get_unhashed_block_ids
()
unhashed_local_block_ids
:
BlockIds
=
(
blocks
.
get_unhashed_block_ids_all_groups
()
if
num_external_tokens
>
0
if
num_external_tokens
>
0
else
[]
else
()
)
)
# Get unhashed blocks to pull from remote.
local_block_ids
=
self
.
get_sw_clipped_blocks
(
unhashed_local_block_ids
)
# Get unhashed blocks to pull from remote. Mind that a full prefix
# cache hit is indicated with an empty list.
self
.
_reqs_need_recv
[
request
.
request_id
]
=
(
self
.
_reqs_need_recv
[
request
.
request_id
]
=
(
request
,
request
,
local_block_ids
,
local_block_ids
,
...
@@ -753,9 +824,10 @@ class NixlConnectorScheduler:
...
@@ -753,9 +824,10 @@ class NixlConnectorScheduler:
req
=
req_to_save
req
=
req_to_save
assert
req
.
kv_transfer_params
is
not
None
assert
req
.
kv_transfer_params
is
not
None
clipped_block_id_groups
=
self
.
get_sw_clipped_blocks
(
new_block_id_groups
)
meta
.
add_new_req_to_save
(
meta
.
add_new_req_to_save
(
request_id
=
req_id
,
request_id
=
req_id
,
local_block_ids
=
new
_block_id_groups
[
0
]
,
local_block_ids
=
clipped
_block_id_groups
,
kv_transfer_params
=
req
.
kv_transfer_params
,
kv_transfer_params
=
req
.
kv_transfer_params
,
)
)
assert
scheduler_output
.
num_scheduled_tokens
is
not
None
assert
scheduler_output
.
num_scheduled_tokens
is
not
None
...
@@ -786,7 +858,7 @@ class NixlConnectorScheduler:
...
@@ -786,7 +858,7 @@ class NixlConnectorScheduler:
def
request_finished
(
def
request_finished
(
self
,
self
,
request
:
"Request"
,
request
:
"Request"
,
block_ids
:
list
[
int
]
,
block_ids
:
BlockIds
,
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
"""
"""
Once a request is finished, determine whether request blocks
Once a request is finished, determine whether request blocks
...
@@ -828,7 +900,7 @@ class NixlConnectorScheduler:
...
@@ -828,7 +900,7 @@ class NixlConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
# remove the conditional below
delay_free_blocks
=
len
(
block_ids
)
>
0
delay_free_blocks
=
any
(
len
(
group
)
>
0
for
group
in
block_ids
)
if
delay_free_blocks
:
if
delay_free_blocks
:
# Prefill request on remote. It will be read from D upon completion
# Prefill request on remote. It will be read from D upon completion
...
@@ -841,6 +913,11 @@ class NixlConnectorScheduler:
...
@@ -841,6 +913,11 @@ class NixlConnectorScheduler:
self
.
_reqs_need_send
[
request
.
request_id
]
=
(
self
.
_reqs_need_send
[
request
.
request_id
]
=
(
time
.
perf_counter
()
+
envs
.
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
time
.
perf_counter
()
+
envs
.
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
)
)
# NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
# trimming down after allocating for the whole sequence length. Empty
# blocks are always at the start of the list.
# Here we "unpad" blocks to send the actual remote blocks to be read.
block_ids
=
self
.
get_sw_clipped_blocks
(
block_ids
)
return
delay_free_blocks
,
dict
(
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
do_remote_prefill
=
True
,
...
@@ -857,7 +934,9 @@ class NixlConnectorScheduler:
...
@@ -857,7 +934,9 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
"""Implementation of Worker side methods"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
,
kv_cache_config
:
"KVCacheConfig"
):
if
NixlWrapper
is
None
:
if
NixlWrapper
is
None
:
logger
.
error
(
"NIXL is not available"
)
logger
.
error
(
"NIXL is not available"
)
raise
RuntimeError
(
"NIXL is not available"
)
raise
RuntimeError
(
"NIXL is not available"
)
...
@@ -875,6 +954,14 @@ class NixlConnectorWorker:
...
@@ -875,6 +954,14 @@ class NixlConnectorWorker:
self
.
nixl_backends
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
self
.
nixl_backends
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"backends"
,
[
"UCX"
]
"backends"
,
[
"UCX"
]
)
)
self
.
_is_hma_required
=
(
not
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
and
any
(
not
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
self
.
kv_cache_config
=
kv_cache_config
# Agent.
# Agent.
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
...
@@ -1017,10 +1104,6 @@ class NixlConnectorWorker:
...
@@ -1017,10 +1104,6 @@ class NixlConnectorWorker:
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self
.
block_window_per_layer
:
list
[
int
|
None
]
=
[]
self
.
use_mla
=
self
.
model_config
.
use_mla
self
.
use_mla
=
self
.
model_config
.
use_mla
# Get the attention backend from the first layer
# Get the attention backend from the first layer
...
@@ -1030,8 +1113,8 @@ class NixlConnectorWorker:
...
@@ -1030,8 +1113,8 @@ class NixlConnectorWorker:
self
.
backend_name
=
self
.
attn_backend
.
get_name
()
self
.
backend_name
=
self
.
attn_backend
.
get_name
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
logger
.
info
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
# lazy initialized in register_kv_caches
# lazy initialized in register_kv_caches
self
.
compat_hash
:
str
|
None
=
None
self
.
compat_hash
:
str
|
None
=
None
...
@@ -1238,9 +1321,15 @@ class NixlConnectorWorker:
...
@@ -1238,9 +1321,15 @@ class NixlConnectorWorker:
"remote_request_id"
:
meta
.
remote
.
request_id
,
"remote_request_id"
:
meta
.
remote
.
request_id
,
"remote_host"
:
meta
.
remote
.
host
,
"remote_host"
:
meta
.
remote
.
host
,
"remote_port"
:
meta
.
remote
.
port
,
"remote_port"
:
meta
.
remote
.
port
,
"num_local_blocks"
:
len
(
meta
.
local_block_ids
),
"num_local_blocks"
:
sum
(
"num_remote_blocks"
:
len
(
meta
.
remote
.
block_ids
),
len
(
group
)
for
group
in
meta
.
local_block_ids
"local_block_ids_sample"
:
meta
.
local_block_ids
[:
10
],
),
"num_remote_blocks"
:
sum
(
len
(
group
)
for
group
in
meta
.
remote
.
block_ids
),
"local_block_ids_sample"
:
meta
.
local_block_ids
[
0
][:
10
]
if
meta
.
local_block_ids
else
[],
}
}
)
)
...
@@ -1301,8 +1390,10 @@ class NixlConnectorWorker:
...
@@ -1301,8 +1390,10 @@ class NixlConnectorWorker:
error
=
e
,
error
=
e
,
meta
=
meta
,
meta
=
meta
,
)
)
if
req_meta
:
=
self
.
_recving_metadata
.
get
(
req_id
):
if
(
self
.
_invalid_block_ids
.
update
(
req_meta
.
local_block_ids
)
req_meta
:
=
self
.
_recving_metadata
.
get
(
req_id
)
)
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
req_meta
.
local_block_ids
[
0
])
self
.
_failed_recv_reqs
.
add
(
req_id
)
self
.
_failed_recv_reqs
.
add
(
req_id
)
fut
.
add_done_callback
(
request_ready
)
fut
.
add_done_callback
(
request_ready
)
...
@@ -1370,6 +1461,10 @@ class NixlConnectorWorker:
...
@@ -1370,6 +1461,10 @@ class NixlConnectorWorker:
for
cache
in
cache_list
:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
base_addr
=
cache
.
data_ptr
()
if
base_addr
in
seen_base_addresses
:
if
base_addr
in
seen_base_addresses
:
# NOTE (NickLucche) HMA employs memory pooling to share tensors
# across groups. This results in skipping all tensors but the ones
# pointed to by group0. Also, generally we will have more blocks
# per tensor but fewer regions.
continue
continue
logger
.
debug
(
logger
.
debug
(
...
@@ -1457,28 +1552,6 @@ class NixlConnectorWorker:
...
@@ -1457,28 +1552,6 @@ class NixlConnectorWorker:
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
register_local_xfer_handler
(
self
.
block_size
)
)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
,
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
# After KV Caches registered, listen for new connections.
agent_metadata
=
NixlAgentMetadata
(
agent_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
...
@@ -1767,6 +1840,11 @@ class NixlConnectorWorker:
...
@@ -1767,6 +1840,11 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
if
self
.
_is_hma_required
:
assert
block_size_ratio
==
1
,
(
"HMA does not support different remote block size yet"
)
kv_cache_layout
=
(
kv_cache_layout
=
(
self
.
kv_cache_layout
self
.
kv_cache_layout
if
not
self
.
use_host_buffer
if
not
self
.
use_host_buffer
...
@@ -1781,6 +1859,9 @@ class NixlConnectorWorker:
...
@@ -1781,6 +1859,9 @@ class NixlConnectorWorker:
"Remote is HND and local is NHD, enabled additional permute "
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
"on local device KV."
)
)
assert
not
self
.
_is_hma_required
,
(
"HMA does not support block size post processing"
)
self
.
enable_permute_local_kv
=
True
self
.
enable_permute_local_kv
=
True
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -1836,13 +1917,15 @@ class NixlConnectorWorker:
...
@@ -1836,13 +1917,15 @@ class NixlConnectorWorker:
assert
self
.
copy_blocks
is
not
None
assert
self
.
copy_blocks
is
not
None
local_block_ids
=
meta
.
local_physical_block_ids
local_block_ids
=
meta
.
local_physical_block_ids
self
.
copy_blocks
(
# TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups
self
.
host_xfer_buffers
,
for
group_block_ids
in
local_block_ids
:
self
.
device_kv_caches
,
self
.
copy_blocks
(
local_block_ids
,
self
.
host_xfer_buffers
,
local_block_ids
,
self
.
device_kv_caches
,
"h2d"
,
group_block_ids
,
)
group_block_ids
,
"h2d"
,
)
if
logger
.
isEnabledFor
(
logging
.
DEBUG
):
if
logger
.
isEnabledFor
(
logging
.
DEBUG
):
logger
.
debug
(
logger
.
debug
(
"synced recved kv of request[%s] to device kv buffer,"
"synced recved kv of request[%s] to device kv buffer,"
...
@@ -1868,13 +1951,14 @@ class NixlConnectorWorker:
...
@@ -1868,13 +1951,14 @@ class NixlConnectorWorker:
","
.
join
(
map
(
str
,
meta
.
local_physical_block_ids
)),
","
.
join
(
map
(
str
,
meta
.
local_physical_block_ids
)),
)
)
# blocking
# blocking
self
.
copy_blocks
(
for
group_block_ids
in
meta
.
local_physical_block_ids
:
self
.
device_kv_caches
,
self
.
copy_blocks
(
self
.
host_xfer_buffers
,
self
.
device_kv_caches
,
meta
.
local_physical_block_ids
,
self
.
host_xfer_buffers
,
meta
.
local_physical_block_ids
,
group_block_ids
,
"d2h"
,
group_block_ids
,
)
"d2h"
,
)
def
post_process_device_kv_on_receive
(
def
post_process_device_kv_on_receive
(
self
,
self
,
...
@@ -1973,8 +2057,9 @@ class NixlConnectorWorker:
...
@@ -1973,8 +2057,9 @@ class NixlConnectorWorker:
if
not
self
.
use_mla
and
(
if
not
self
.
use_mla
and
(
block_size_ratio
>
1
or
self
.
enable_permute_local_kv
block_size_ratio
>
1
or
self
.
enable_permute_local_kv
):
):
assert
not
self
.
_is_hma_required
block_ids_for_blocksize_post_process
[
block_size_ratio
].
append
(
block_ids_for_blocksize_post_process
[
block_size_ratio
].
append
(
meta
.
local_physical_block_ids
meta
.
local_physical_block_ids
[
0
]
)
)
for
(
for
(
block_size_ratio
,
block_size_ratio
,
...
@@ -2106,8 +2191,9 @@ class NixlConnectorWorker:
...
@@ -2106,8 +2191,9 @@ class NixlConnectorWorker:
handle: The transfer handle.
handle: The transfer handle.
"""
"""
# Use .get() here as the metadata cleanup is handled by get_finished()
# Use .get() here as the metadata cleanup is handled by get_finished()
if
meta
:
=
self
.
_recving_metadata
.
get
(
req_id
):
# TODO (NickLucche) handle failed transfer for HMA.
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
if
(
meta
:
=
self
.
_recving_metadata
.
get
(
req_id
))
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
[
0
])
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
xfer_stats
.
record_failed_transfer
()
self
.
xfer_stats
.
record_failed_transfer
()
...
@@ -2230,8 +2316,8 @@ class NixlConnectorWorker:
...
@@ -2230,8 +2316,8 @@ class NixlConnectorWorker:
def
_read_blocks
(
def
_read_blocks
(
self
,
self
,
local_block_ids
:
list
[
int
]
,
local_block_ids
:
BlockIds
,
remote_block_ids
:
list
[
int
]
,
remote_block_ids
:
BlockIds
,
dst_engine_id
:
str
,
dst_engine_id
:
str
,
request_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
remote_request_id
:
str
,
...
@@ -2246,22 +2332,30 @@ class NixlConnectorWorker:
...
@@ -2246,22 +2332,30 @@ class NixlConnectorWorker:
assert
self
.
kv_topo
is
not
None
assert
self
.
kv_topo
is
not
None
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
local_block_ids
=
self
.
get_mapped_blocks
(
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
np
.
asarray
(
local_block_ids
),
block_size_ratio
assert
not
self
.
_is_hma_required
)
local_block_ids0
=
local_block_ids
[
0
]
if
local_block_ids
else
[]
if
len
(
local_block_ids
)
>
len
(
remote_block_ids
):
remote_block_ids0
=
remote_block_ids
[
0
]
local_block_ids_mapped
=
self
.
get_mapped_blocks
(
np
.
asarray
(
local_block_ids0
),
block_size_ratio
).
tolist
()
if
len
(
local_block_ids_mapped
)
>
len
(
remote_block_ids0
):
# NOTE:
# NOTE:
# get_mapped_blocks will always expand block_ids for n times.
# get_mapped_blocks will always expand block_ids for n times.
# ex:
# ex:
# prefill block_ids with block_size as 4:
# prefill block_ids with block_size as 4:
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Local decode block_ids with block_size as 16: [1, 2, 3]
# Local decode block_ids with block_size as 16: [1, 2, 3]
# exp
l
and
ecode block_ids with get_mapped_blocks from [1, 2, 3] to
# expand
ed d
ecode block_ids with get_mapped_blocks from [1, 2, 3] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# Then we clip local to align with prefill
# Then we clip local to align with prefill
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
local_block_ids
=
local_block_ids
[:
len
(
remote_block_ids
)]
local_block_ids_mapped
=
local_block_ids_mapped
[
:
len
(
remote_block_ids0
)
]
local_block_ids
=
[
local_block_ids_mapped
]
if
local_block_ids_mapped
else
[]
remote_block_ids
=
[
remote_block_ids0
]
# NOTE(rob): having the staging blocks be on the READER side is
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
# after we detect the txn is complete (which means we cannot make the
...
@@ -2269,8 +2363,7 @@ class NixlConnectorWorker:
...
@@ -2269,8 +2363,7 @@ class NixlConnectorWorker:
# then we will need to have the staging blocks on the remote side.
# then we will need to have the staging blocks on the remote side.
# NOTE(rob): according to nvidia the staging blocks are used to
# NOTE(rob): according to nvidia the staging blocks are used to
# saturate IB with heterogeneous TP sizes. We should remove the staging
# saturate IB with heterogeneous TP sizes.
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate info
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
# on notification so that dst worker can wait before freeing blocks.
...
@@ -2278,8 +2371,8 @@ class NixlConnectorWorker:
...
@@ -2278,8 +2371,8 @@ class NixlConnectorWorker:
# Full prefix cache hit: do not need to read remote blocks,
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
if
len
(
local_block_ids
)
==
0
:
if
num_local_blocks
==
0
:
# A full prefix cache hit is indicated with an empty list.
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
@@ -2297,66 +2390,34 @@ class NixlConnectorWorker:
...
@@ -2297,66 +2390,34 @@ class NixlConnectorWorker:
self
.
xfer_stats
.
record_failed_notification
()
self
.
xfer_stats
.
record_failed_notification
()
return
return
# Partial prefix cache hit: just read uncomputed blocks.
assert
(
num_remote_blocks
=
len
(
remote_block_ids
)
len
(
remote_block_ids
)
assert
num_local_blocks
<=
num_remote_blocks
==
len
(
local_block_ids
)
if
num_local_blocks
<
num_remote_blocks
:
==
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
)
remote_block_ids
=
list
(
remote_block_ids
)
for
i
,
remote_group
in
enumerate
(
remote_block_ids
):
num_remote_blocks
=
len
(
remote_group
)
num_local_blocks
=
len
(
local_block_ids
[
i
])
assert
num_local_blocks
<=
num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
if
num_local_blocks
<
num_remote_blocks
:
remote_block_ids
[
i
]
=
remote_group
[
-
num_local_blocks
:]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
# Get descs ids.
local_block_descs_ids
:
np
.
ndarray
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
remote_block_descs_ids
:
np
.
ndarray
dst_engine_id
,
remote_block_ids
,
if
not
self
.
block_window_per_layer
:
)
# Default case: assume global attention
local_block_descs_ids
=
self
.
_get_block_descs_ids
(
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
dst_engine_id
,
local_block_ids
,
remote_block_ids
,
block_size_ratio
=
block_size_ratio
,
)
)
local_block_descs_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
local_block_ids
,
block_size_ratio
=
block_size_ratio
,
)
else
:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list
=
[]
remote_descs_list
=
[]
for
layer_idx
,
block_window
in
enumerate
(
self
.
block_window_per_layer
):
# For each layer:
if
block_window
is
None
:
# If not chunked, we just use the
# full block lists (global attention)
layer_local_block_ids
=
local_block_ids
layer_remote_block_ids
=
remote_block_ids
else
:
# If chunked, get the last block_window blocks
layer_local_block_ids
=
local_block_ids
[
-
block_window
:]
layer_remote_block_ids
=
remote_block_ids
[
-
block_window
:]
# Get descs ids for the layer.
layer_local_desc_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
layer_local_block_ids
,
layer_idx
,
block_size_ratio
=
block_size_ratio
,
)
layer_remote_desc_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
layer_remote_block_ids
,
layer_idx
,
)
local_descs_list
.
append
(
layer_local_desc_ids
)
remote_descs_list
.
append
(
layer_remote_desc_ids
)
local_block_descs_ids
=
np
.
concatenate
(
local_descs_list
)
remote_block_descs_ids
=
np
.
concatenate
(
remote_descs_list
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
...
@@ -2387,14 +2448,18 @@ class NixlConnectorWorker:
...
@@ -2387,14 +2448,18 @@ class NixlConnectorWorker:
dst_engine_id
=
dst_engine_id
,
dst_engine_id
=
dst_engine_id
,
remote_rank
=
remote_rank
,
remote_rank
=
remote_rank
,
)
)
if
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
):
if
(
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
)
)
and
not
self
.
_is_hma_required
:
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
[
0
])
self
.
xfer_stats
.
record_failed_transfer
()
self
.
xfer_stats
.
record_failed_transfer
()
if
handle
is
not
None
:
if
handle
is
not
None
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_failed_recv_reqs
.
add
(
request_id
)
self
.
_failed_recv_reqs
.
add
(
request_id
)
def
get_mapped_blocks
(
self
,
block_ids
,
block_size_ratio
):
def
get_mapped_blocks
(
self
,
block_ids
:
np
.
ndarray
,
block_size_ratio
:
int
)
->
np
.
ndarray
:
"""
"""
Calculates the new set of block IDs by mapping every element
Calculates the new set of block IDs by mapping every element
in the (potentially sparse) input array.
in the (potentially sparse) input array.
...
@@ -2416,41 +2481,32 @@ class NixlConnectorWorker:
...
@@ -2416,41 +2481,32 @@ class NixlConnectorWorker:
def
_get_block_descs_ids
(
def
_get_block_descs_ids
(
self
,
self
,
engine_id
:
str
,
engine_id
:
str
,
block_ids
:
list
[
int
],
block_ids
:
BlockIds
,
layer_idx
:
int
|
None
=
None
,
block_size_ratio
:
float
|
None
=
None
,
block_size_ratio
:
float
|
None
=
None
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
"""
"""
Get the descs ids for a set of block ids.
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given lay
er.
When HMA is enabled number of descriptors across kv cache groups might diff
er.
Otherwise, we use all regions
.
A single flattened array is returned for all groups anyway
.
"""
"""
if
layer_idx
is
None
:
region_ids
=
np
.
arange
(
self
.
num_regions
)
region_ids
=
np
.
arange
(
self
.
num_regions
)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
else
:
# layers from different groups share the same kv tensor.
assert
layer_idx
<
self
.
num_layers
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
if
self
.
num_layers
<
self
.
num_regions
:
# same for [3], but group0-group1 blocks will always differ (different areas).
# If we have more regions than layers, we assume that
# Therefore we can just flatten the block_ids and compute the descs ids for all
# the regions are organized as [K0, V0, K1, V1, ...]
# groups at once.
# and we select K_i and V_i
assert
2
*
self
.
num_layers
==
self
.
num_regions
region_ids
=
np
.
arange
(
2
*
layer_idx
,
2
*
layer_idx
+
2
)
else
:
# Otherwise, we assume we have MLA and select i-th layer
assert
self
.
num_layers
==
self
.
num_regions
region_ids
=
np
.
arange
(
layer_idx
,
layer_idx
+
1
)
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
if
block_size_ratio
is
not
None
:
if
block_size_ratio
is
not
None
:
num_blocks
=
int
(
num_blocks
*
block_size_ratio
)
num_blocks
=
int
(
num_blocks
*
block_size_ratio
)
# Compute the desc ids for each block.
# Compute the desc ids for each block.
region_ids
=
region_ids
[:,
None
]
region_ids
=
region_ids
[:,
None
]
block_ids
=
np
.
array
(
block_ids
)[
None
,
:]
block_ids
=
np
.
concatenate
(
block_ids
)[
None
,
:]
descs_ids
=
region_ids
*
num_blocks
+
block_ids
descs_ids
=
region_ids
*
num_blocks
+
block_ids
return
descs_ids
.
flatten
()
return
descs_ids
.
flatten
()
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
list
[
int
])
->
list
[
int
]
:
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
"""
"""
Convert logical block ids to kernel physical block ids.
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
This is required when the logical block size (the one set by the user)
...
@@ -2459,13 +2515,17 @@ class NixlConnectorWorker:
...
@@ -2459,13 +2515,17 @@ class NixlConnectorWorker:
if
self
.
_physical_blocks_per_logical_kv_block
==
1
:
if
self
.
_physical_blocks_per_logical_kv_block
==
1
:
# Noop when physical and logical block sizes are the same
# Noop when physical and logical block sizes are the same
return
block_ids
return
block_ids
block_ids_np
=
np
.
array
(
block_ids
)
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
1
,
-
1
1
,
-
1
)
)
return
BlockTable
.
map_to_kernel_blocks
(
return
[
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
BlockTable
.
map_to_kernel_blocks
(
).
tolist
()
np
.
array
(
group
),
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
,
).
tolist
()
for
group
in
block_ids
]
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
"""
"""
...
...
vllm/v1/core/kv_cache_manager.py
View file @
5b3ba94a
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
...
@@ -84,6 +84,18 @@ class KVCacheBlocks:
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
return
[
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
def
get_unhashed_block_ids_all_groups
(
self
)
->
list
[
list
[
int
]]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
# Skip padding blocks.
return
[
[
block
.
block_id
for
block
in
group
if
block
.
block_hash
is
None
and
not
block
.
is_null
]
for
group
in
self
.
blocks
]
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""
"""
Creates a new KVCacheBlocks instance with no blocks.
Creates a new KVCacheBlocks instance with no blocks.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment