Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
25826835
Unverified
Commit
25826835
authored
Jun 26, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 25, 2025
Browse files
[PD] Skip `tp_size` exchange with rank0 (#19413)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
754b00ed
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
66 deletions
+72
-66
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+23
-6
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+49
-60
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
25826835
...
...
@@ -7,6 +7,8 @@ from collections import defaultdict
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
...
...
@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
)
->
dict
[
int
,
str
]:
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time
.
sleep
(
self
.
_hand_shake_latency
)
# These should've been done in register_kv_caches(), called by
...
...
@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
num_blocks
=
1
,
tp_size
=
1
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
,
))
),
remote_tp_size
=
remote_tp_size
)
return
{
0
:
remote_agent_name
}
...
...
@@ -233,6 +236,8 @@ class TestNixlHandshake:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
...
...
@@ -259,13 +264,23 @@ class TestNixlHandshake:
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
@
pytest
.
mark
.
parametrize
(
"decode_tp_size, prefill_tp_size"
,
[
(
1
,
1
),
(
2
,
1
),
(
4
,
2
),
(
4
,
4
),
])
def
test_async_load_kv
(
self
,
# dist_init is a fixture that initializes the distributed environment.
dist_init
):
self
,
# Fixture that initializes the distributed environment.
dist_init
,
# Simulate consumer-producer TP sizes.
decode_tp_size
,
prefill_tp_size
):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config
=
create_vllm_config
()
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
...
...
@@ -280,6 +295,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
prefill_tp_size
,
})
connector
.
bind_connector_metadata
(
metadata
)
...
...
@@ -329,6 +345,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
25826835
...
...
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata
:
bytes
kv_caches_base_addr
:
list
[
int
]
num_blocks
:
int
tp_size
:
int
block_len
:
int
attn_backend_name
:
str
...
...
@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids
:
list
[
int
]
remote_host
:
str
remote_port
:
int
remote_engine_id
:
EngineId
remote_engine_id
:
str
tp_size
:
int
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
...
...
@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
# P workers don't need to receive tp_size from proxy here.
tp_size
=
kv_transfer_params
.
get
(
"tp_size"
,
1
),
)
...
...
@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
remote_engine_id
=
self
.
engine_id
,
remote_host
=
self
.
side_channel_host
,
remote_port
=
self
.
side_channel_port
,
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
)
class
NixlConnectorWorker
:
...
...
@@ -473,7 +475,8 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s"
,
msg
)
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
))
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
)
->
dict
[
int
,
str
]:
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
...
...
@@ -482,7 +485,7 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def
handshake
(
path
:
str
,
rank
:
int
)
->
tuple
[
NixlAgentMetadata
,
str
]
:
def
handshake
(
path
:
str
,
rank
:
int
)
->
str
:
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
sock
.
send
(
GET_META_MSG
)
...
...
@@ -492,33 +495,25 @@ class NixlConnectorWorker:
got_metadata_time
=
time
.
perf_counter
()
# Register Remote agent.
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
rank
)
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
return
metadata
,
remote_agent_name
return
remote_agent_name
# Handshake with remote agent-rank0 first to get the tp_size of remote
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying master rank metadata on path: %s"
,
path
)
rank_to_agent_name
:
dict
[
int
,
str
]
=
{}
metadata
,
rank_to_agent_name
[
0
]
=
handshake
(
path
,
0
)
# Handshake only with the other TP remote the current local rank will
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
metadata
.
tp_size
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
remote_
tp_size
p_remote_rank
=
self
.
tp_rank
//
tp_ratio
if
p_remote_rank
>
0
:
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
p_remote_rank
)
_
,
rank_to_agent_name
[
p_remote_rank
]
=
handshake
(
path
,
p_remote_rank
)
return
rank_to_agent_name
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
p_remote_rank
)
# Remote rank -> agent name.
return
{
p_remote_rank
:
handshake
(
path
,
p_remote_rank
)}
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
...
...
@@ -645,7 +640,6 @@ class NixlConnectorWorker:
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
num_blocks
=
self
.
num_blocks
,
tp_size
=
self
.
world_size
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
)
ready_event
=
threading
.
Event
()
...
...
@@ -659,7 +653,8 @@ class NixlConnectorWorker:
def
add_remote_agent
(
self
,
nixl_agent_meta
:
NixlAgentMetadata
,
remote_tp_rank
:
int
=
0
)
->
str
:
remote_tp_rank
:
int
=
0
,
remote_tp_size
:
int
=
1
)
->
str
:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
...
...
@@ -704,9 +699,9 @@ class NixlConnectorWorker:
return
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
if
engine_id
in
self
.
_tp_size
:
assert
self
.
_tp_size
[
engine_id
]
==
nixl_agent_meta
.
tp_size
assert
self
.
_tp_size
[
engine_id
]
==
remote_
tp_size
else
:
self
.
_tp_size
[
engine_id
]
=
nixl_agent_meta
.
tp_size
self
.
_tp_size
[
engine_id
]
=
remote_
tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
...
...
@@ -756,33 +751,31 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank
=
self
.
tp_rank
//
tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
if
p_remote_tp_rank
==
remote_tp_rank
:
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
# Register all remote blocks, but only the corresponding kv heads.
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
self
.
tp_rank
)
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
# Register all remote blocks, but only the corresponding kv heads.
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
self
.
tp_rank
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
return
remote_agent_name
...
...
@@ -917,7 +910,7 @@ class NixlConnectorWorker:
if
fut
is
None
:
fut
=
self
.
_handshake_initiation_executor
.
submit
(
self
.
_nixl_handshake
,
meta
.
remote_host
,
meta
.
remote_port
)
meta
.
remote_port
,
meta
.
tp_size
)
self
.
_handshake_futures
[
remote_engine_id
]
=
fut
def
done_callback
(
f
:
Future
[
dict
[
int
,
str
]],
...
...
@@ -957,13 +950,9 @@ class NixlConnectorWorker:
remote_block_ids
=
meta
.
remote_block_ids
,
)
def
_read_blocks
(
self
,
local_block_ids
:
list
[
int
],
remote_block_ids
:
list
[
int
],
dst_engine_id
:
str
,
request_id
:
str
,
):
def
_read_blocks
(
self
,
local_block_ids
:
list
[
int
],
remote_block_ids
:
list
[
int
],
dst_engine_id
:
str
,
request_id
:
str
):
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment