Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
72f431e7
Unverified
Commit
72f431e7
authored
Oct 21, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Oct 21, 2025
Browse files
[Nixl] Minor refactor to handshake related metadata (#26410)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
be444507
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
176 additions
and
88 deletions
+176
-88
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+0
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+176
-86
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
72f431e7
...
...
@@ -565,8 +565,6 @@ class TestNixlHandshake:
kv_cache_layout
=
mismatched_layout
,
)
# We don't check layout for homogeneous TP and MLA for now, as the
# whole block is moved.
with
pytest
.
raises
(
RuntimeError
):
# mismatched layout is expected to fail
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
2
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
72f431e7
...
...
@@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size
,
get_tp_group
,
)
from
vllm.distributed.utils
import
divide
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
...
...
@@ -521,6 +520,72 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
@
dataclass
class
TpKVTopology
:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_size
:
int
tp_rank
:
int
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
total_num_kv_heads
:
int
def
tp_ratio
(
self
,
remote_tp_size
:
int
,
)
->
int
:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
f
"by remote tensor parallel size
{
remote_tp_size
}
."
)
return
self
.
tp_size
//
remote_tp_size
def
tp_ratio_from_engine_id
(
self
,
remote_engine_id
:
EngineId
,
)
->
int
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
tp_ratio
(
remote_tp_size
)
def
is_kv_replicated
(
self
,
engine_id
:
EngineId
)
->
bool
:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size
=
self
.
remote_tp_size
[
engine_id
]
return
tp_size
//
self
.
total_num_kv_heads
>=
1
def
replicates_kv_cache
(
self
,
remote_engine_id
:
EngineId
)
->
bool
:
# MLA is always replicated as the hidden dim can't be split.
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
def
get_target_remote_rank
(
self
,
remote_tp_size
:
int
,
)
->
int
:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_rank
//
tp_ratio
def
get_target_remote_rank_from_engine_id
(
self
,
remote_engine_id
:
EngineId
,
)
->
int
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_rank
(
remote_tp_size
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
if
NixlWrapper
is
None
:
logger
.
error
(
"NIXL is not available"
)
...
...
@@ -534,6 +599,7 @@ class NixlConnectorWorker:
if
vllm_config
.
kv_transfer_config
is
None
:
raise
ValueError
(
"kv_transfer_config must be set for NixlConnector"
)
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
nixl_backends
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"backends"
,
[
"UCX"
]
...
...
@@ -654,7 +720,6 @@ class NixlConnectorWorker:
# Protects _handshake_futures and _remote_agents.
self
.
_handshake_lock
=
threading
.
RLock
()
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
...
...
@@ -686,6 +751,14 @@ class NixlConnectorWorker:
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
xfer_stats
=
NixlKVConnectorStats
()
self
.
kv_topo
=
self
.
TpKVTopology
(
tp_size
=
self
.
world_size
,
tp_rank
=
self
.
tp_rank
,
remote_tp_size
=
self
.
_tp_size
,
# shared state
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
)
@
staticmethod
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
...
...
@@ -731,8 +804,7 @@ class NixlConnectorWorker:
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
remote_tp_size
p_remote_rank
=
self
.
tp_rank
//
tp_ratio
p_remote_rank
=
self
.
kv_topo
.
get_target_remote_rank
(
remote_tp_size
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
p_remote_rank
...
...
@@ -989,13 +1061,11 @@ class NixlConnectorWorker:
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
if
self
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
assert
isinstance
(
self
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
...
...
@@ -1078,36 +1148,106 @@ class NixlConnectorWorker:
engine_id
=
nixl_agent_meta
.
engine_id
# TODO re-evaluate refreshing for scaling/recovery
if
remote_tp_rank
in
self
.
_remote_agents
.
get
(
engine_id
,
{}):
logger
.
debug
(
"Remote agent with engine_id %s and rank"
"%s already exchanged metadata, skip handshake."
,
engine_id
,
remote_tp_rank
,
)
return
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
### Register remote agent metadata
if
engine_id
not
in
self
.
_tp_size
:
self
.
_tp_size
[
engine_id
]
=
remote_tp_size
else
:
assert
self
.
_tp_size
[
engine_id
]
==
remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
remote_agent_name
=
self
.
nixl_wrapper
.
add_remote_agent
(
nixl_agent_meta
.
agent_metadata
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache
=
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
if
engine_id
not
in
self
.
dst_num_blocks
:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
# Keep track of remote agent kv caches base addresses.
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
tp_ratio
=
divide
(
self
.
_tp_size
[
self
.
engine_id
],
self
.
_tp_size
[
engine_id
])
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
### Register remote agent memory regions
blocks_data
=
[]
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Register all remote blocks, but only the corresponding kv heads.
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
kv_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
rank_offset
=
(
self
.
tp_rank
%
tp_ratio
*
kv_block_len
if
not
replicates_kv_cache
else
0
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv_block_len
,
remote_tp_rank
))
if
self
.
_use_flashinfer
:
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
((
v_addr
,
kv_block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
self
.
tp_rank
,
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
return
remote_agent_name
def
_validate_remote_agent_handshake
(
self
,
nixl_agent_meta
:
NixlAgentMetadata
,
remote_tp_size
:
int
):
"""
Validate the remote agent handshake metadata ensuring the
invariants hold true.
"""
remote_engine_id
=
nixl_agent_meta
.
engine_id
assert
self
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
assert
tp_ratio
>
0
,
"Decode TP cannot be smaller than prefill TP"
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
# Handle tp_size>num_kv_heads: replicate KV cache.
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
()
is_kv_replicated
=
self
.
_tp_size
[
engine_id
]
//
total_num_kv_heads
>=
1
remote_block_len
=
nixl_agent_meta
.
block_lens
[
0
]
if
nixl_agent_meta
.
kv_cache_layout
!=
self
.
kv_cache_layout
:
if
not
self
.
use_mla
and
nixl_agent_meta
.
kv_cache_layout
!=
self
.
kv_cache_layout
:
if
(
self
.
vllm_config
.
kv_transfer_config
is
not
None
and
self
.
vllm_config
.
kv_transfer_config
.
enable_permute_local_kv
self
.
kv_transfer_config
.
enable_permute_local_kv
and
nixl_agent_meta
.
kv_cache_layout
==
"HND"
):
logger
.
info
(
...
...
@@ -1121,13 +1261,19 @@ class NixlConnectorWorker:
"Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
if
self
.
use_mla
or
is_kv_replicated
:
# Block len can only vary across layers when using MLA.
remote_block_len
=
nixl_agent_meta
.
block_lens
[
0
]
if
self
.
use_mla
or
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
):
# With replicated KV cache, only the number of blocks can differ.
assert
self
.
block_len_per_layer
==
nixl_agent_meta
.
block_lens
,
(
"KV cache sizes must match between P and D when replicated"
)
remote_block_size
=
remote_block_len
//
(
self
.
slot_size_per_layer
[
0
])
else
:
if
tp_ratio
>
1
and
self
.
device_type
==
"xpu"
:
# XPU uses NHD, hence it does not support splitting on H
raise
ValueError
(
"Heterogeneous TP is not supported on XPU"
)
# When MLA is not used, this is a list of the same block length
for
block_len
in
nixl_agent_meta
.
block_lens
:
assert
block_len
==
remote_block_len
,
(
...
...
@@ -1139,14 +1285,6 @@ class NixlConnectorWorker:
if
self
.
_use_flashinfer
:
# With flashinfer, KV are sent in the same message.
remote_block_size
//=
2
if
tp_ratio
>
1
:
# Heterogeneous TP expects same kv_cache_layout.
if
nixl_agent_meta
.
kv_cache_layout
==
"NHD"
:
raise
ValueError
(
"Heterogeneous TP is not supported for remote with NHD."
)
if
self
.
device_type
==
"xpu"
:
raise
ValueError
(
"Heterogeneous TP is not supported on XPU"
)
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
*
tp_ratio
,
(
"Remote P worker KV layer cache must be of shape [2, N, "
...
...
@@ -1158,60 +1296,10 @@ class NixlConnectorWorker:
f
"
{
self
.
block_size
=
}
,
{
remote_block_size
=
}
"
)
# Create dst descs and xfer side handles. TP workers have same #blocks.
if
engine_id
in
self
.
dst_num_blocks
:
assert
self
.
dst_num_blocks
[
engine_id
]
==
nixl_agent_meta
.
num_blocks
else
:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
blocks_data
=
[]
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
# TP workers have same #blocks.
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
# Register all remote blocks, but only the corresponding kv heads.
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
kv_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
rank_offset
=
(
self
.
tp_rank
%
tp_ratio
*
kv_block_len
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv_block_len
,
remote_tp_rank
))
if
self
.
_use_flashinfer
:
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
((
v_addr
,
kv_block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
self
.
tp_rank
,
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
return
remote_agent_name
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
"""copy recved kv from host buffer to device."""
...
...
@@ -1505,14 +1593,16 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
self
.
_tp_size
[
dst_engine_id
]
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
notif_id
=
f
"
{
request_id
}
:
{
tp_ratio
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
remote_rank
=
self
.
tp_rank
//
tp_ratio
remote_rank
=
self
.
kv_topo
.
get_target_remote_rank_from_engine_id
(
dst_engine_id
)
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment