Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
422f22e0
Unverified
Commit
422f22e0
authored
Aug 12, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Aug 12, 2025
Browse files
[CI][Nixl] Check kv cache layout during handshake (#22745)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
6bd8ebf0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
3 deletions
+56
-3
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+46
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+10
-3
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
422f22e0
...
...
@@ -419,6 +419,52 @@ class TestNixlHandshake:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
    """
    Verify that adding a remote agent fails if kv_cache_layout differs.
    This test is only relevant for heterogeneous TP.

    Two remote engines with distinct ids are registered so that each
    `add_remote_agent` call is evaluated on its own: the worker asserts
    that an already-recorded TP size for an engine id matches the new
    one, which would otherwise raise an unrelated AssertionError and
    let this test pass without the layout check ever running.
    """
    vllm_config = create_vllm_config()

    # Mock TP world size to 2 to force heterogeneous TP when
    # remote_tp_size=1
    with patch(
            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
            return_value=2):
        # Initialize connector and worker (with fake NIXL wrapper)
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0)
        worker = connector.connector_worker

        # Minimal local registration params used by add_remote_agent
        worker.slot_size_bytes = 4096
        worker.block_len = worker.slot_size_bytes * worker.block_size
        worker.num_blocks = 1
        worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

        # Pick whichever layout the local worker is NOT using.
        mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
            else "NHD"

        def make_mismatched_meta(engine_id):
            """Build remote metadata that differs from the local worker
            only in its kv_cache_layout."""
            return NixlAgentMetadata(
                engine_id=engine_id,
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                kv_caches_base_addr=[0],
                num_blocks=1,
                block_len=worker.block_len,
                attn_backend_name=worker.backend_name,
                kv_cache_layout=mismatched_layout,
            )

        homogeneous_meta = make_mismatched_meta(
            FakeNixlConnectorWorker.REMOTE_ENGINE_ID)
        # Use a separate engine id for the heterogeneous case so that the
        # TP size recorded by the homogeneous call above cannot be what
        # triggers the AssertionError.
        heterogeneous_meta = make_mismatched_meta(
            f"{FakeNixlConnectorWorker.REMOTE_ENGINE_ID}-hetero-tp")

        # We don't check layout for homogeneous TP and MLA for now, as the
        # whole block is moved.
        worker.add_remote_agent(homogeneous_meta, remote_tp_size=2)

        # Heterogeneous TP (local 2, remote 1): the layout mismatch must
        # be rejected.
        with pytest.raises(AssertionError):
            worker.add_remote_agent(heterogeneous_meta, remote_tp_size=1)
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
422f22e0
...
...
@@ -30,6 +30,7 @@ from vllm.forward_context import ForwardContext
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
...
...
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
num_blocks
:
int
block_len
:
int
attn_backend_name
:
str
kv_cache_layout
:
str
@
dataclass
...
...
@@ -538,7 +540,9 @@ class NixlConnectorWorker:
attn_backend
=
backend_name_to_enum
(
self
.
backend_name
)
self
.
_use_flashinfer
=
attn_backend
==
_Backend
.
FLASHINFER_VLLM_V1
self
.
_use_pallas_v1
=
attn_backend
==
_Backend
.
PALLAS_VLLM_V1
self
.
kv_cache_layout
=
get_kv_cache_layout
()
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
...
...
@@ -839,7 +843,8 @@ class NixlConnectorWorker:
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
num_blocks
=
self
.
num_blocks
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
)
attn_backend_name
=
self
.
backend_name
,
kv_cache_layout
=
self
.
kv_cache_layout
)
ready_event
=
threading
.
Event
()
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
target
=
self
.
_nixl_handshake_listener
,
...
...
@@ -900,8 +905,7 @@ class NixlConnectorWorker:
self
.
_tp_size
[
engine_id
]
=
remote_tp_size
else
:
assert
self
.
_tp_size
[
engine_id
]
==
remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
# TODO We may eventually want to skip enforcing the same attn backend.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
remote_agent_name
=
self
.
nixl_wrapper
.
add_remote_agent
(
...
...
@@ -930,6 +934,9 @@ class NixlConnectorWorker:
if
self
.
_use_flashinfer
:
# Account for joint KV in FlashInfer.
remote_block_size
//=
2
if
tp_ratio
>
1
:
# Heterogeneous TP expects same kv_cache_layout.
assert
nixl_agent_meta
.
kv_cache_layout
==
self
.
kv_cache_layout
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
(
"Remote P worker KV layer cache must be of shape [2, N, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment