Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
25826835
Unverified
Commit
25826835
authored
Jun 26, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 25, 2025
Browse files
[PD] Skip `tp_size` exchange with rank0 (#19413)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
754b00ed
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
66 deletions
+72
-66
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+23
-6
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+49
-60
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
25826835
...
@@ -7,6 +7,8 @@ from collections import defaultdict
...
@@ -7,6 +7,8 @@ from collections import defaultdict
from
typing
import
Optional
from
typing
import
Optional
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
NixlConnectorWorker
)
...
@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
_hand_shake_latency
=
hand_shake_latency
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
)
->
dict
[
int
,
str
]:
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time
.
sleep
(
self
.
_hand_shake_latency
)
time
.
sleep
(
self
.
_hand_shake_latency
)
# These should've been done in register_kv_caches(), called by
# These should've been done in register_kv_caches(), called by
...
@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
kv_caches_base_addr
=
[
0
],
num_blocks
=
1
,
num_blocks
=
1
,
tp_size
=
1
,
block_len
=
self
.
block_len
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
,
attn_backend_name
=
self
.
backend_name
,
))
),
remote_tp_size
=
remote_tp_size
)
return
{
0
:
remote_agent_name
}
return
{
0
:
remote_agent_name
}
...
@@ -233,6 +236,8 @@ class TestNixlHandshake:
...
@@ -233,6 +236,8 @@ class TestNixlHandshake:
"localhost"
,
"localhost"
,
"remote_port"
:
"remote_port"
:
1234
,
1234
,
"remote_tp_size"
:
1
,
})
})
connector
.
bind_connector_metadata
(
metadata
)
connector
.
bind_connector_metadata
(
metadata
)
...
@@ -259,13 +264,23 @@ class TestNixlHandshake:
...
@@ -259,13 +264,23 @@ class TestNixlHandshake:
@
patch
(
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
FakeNixlWrapper
)
@
pytest
.
mark
.
parametrize
(
"decode_tp_size, prefill_tp_size"
,
[
(
1
,
1
),
(
2
,
1
),
(
4
,
2
),
(
4
,
4
),
])
def
test_async_load_kv
(
def
test_async_load_kv
(
self
,
self
,
# dist_init is a fixture that initializes the distributed environment.
# Fixture that initializes the distributed environment.
dist_init
):
dist_init
,
# Simulate consumer-producer TP sizes.
decode_tp_size
,
prefill_tp_size
):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
...
@@ -280,6 +295,7 @@ class TestNixlHandshake:
...
@@ -280,6 +295,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_port"
:
1234
,
"remote_tp_size"
:
prefill_tp_size
,
})
})
connector
.
bind_connector_metadata
(
metadata
)
connector
.
bind_connector_metadata
(
metadata
)
...
@@ -329,6 +345,7 @@ class TestNixlHandshake:
...
@@ -329,6 +345,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
})
connector
.
bind_connector_metadata
(
metadata
)
connector
.
bind_connector_metadata
(
metadata
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
25826835
...
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
...
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata
:
bytes
agent_metadata
:
bytes
kv_caches_base_addr
:
list
[
int
]
kv_caches_base_addr
:
list
[
int
]
num_blocks
:
int
num_blocks
:
int
tp_size
:
int
block_len
:
int
block_len
:
int
attn_backend_name
:
str
attn_backend_name
:
str
...
@@ -73,7 +72,8 @@ class ReqMeta:
...
@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids
:
list
[
int
]
remote_block_ids
:
list
[
int
]
remote_host
:
str
remote_host
:
str
remote_port
:
int
remote_port
:
int
remote_engine_id
:
EngineId
remote_engine_id
:
str
tp_size
:
int
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
...
@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
# P workers don't need to receive tp_size from proxy here.
tp_size
=
kv_transfer_params
.
get
(
"tp_size"
,
1
),
)
)
...
@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
...
@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
remote_engine_id
=
self
.
engine_id
,
remote_engine_id
=
self
.
engine_id
,
remote_host
=
self
.
side_channel_host
,
remote_host
=
self
.
side_channel_host
,
remote_port
=
self
.
side_channel_port
,
remote_port
=
self
.
side_channel_port
,
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
)
class
NixlConnectorWorker
:
class
NixlConnectorWorker
:
...
@@ -473,7 +475,8 @@ class NixlConnectorWorker:
...
@@ -473,7 +475,8 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s"
,
msg
)
"Connection listener got unexpected message %s"
,
msg
)
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
))
sock
.
send_multipart
((
identity
,
b
""
,
encoded_data
))
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
)
->
dict
[
int
,
str
]:
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
...
@@ -482,7 +485,7 @@ class NixlConnectorWorker:
...
@@ -482,7 +485,7 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# or where we have a single ZMQ socket in the scheduler.
def
handshake
(
path
:
str
,
rank
:
int
)
->
tuple
[
NixlAgentMetadata
,
str
]
:
def
handshake
(
path
:
str
,
rank
:
int
)
->
str
:
# Send query for the request.
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
sock
.
send
(
GET_META_MSG
)
sock
.
send
(
GET_META_MSG
)
...
@@ -492,33 +495,25 @@ class NixlConnectorWorker:
...
@@ -492,33 +495,25 @@ class NixlConnectorWorker:
got_metadata_time
=
time
.
perf_counter
()
got_metadata_time
=
time
.
perf_counter
()
# Register Remote agent.
# Register Remote agent.
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
rank
)
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
got_metadata_time
-
start_time
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
setup_agent_time
-
got_metadata_time
)
return
metadata
,
remote_agent_name
return
remote_agent_name
# Handshake with remote agent-rank0 first to get the tp_size of remote
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying master rank metadata on path: %s"
,
path
)
rank_to_agent_name
:
dict
[
int
,
str
]
=
{}
metadata
,
rank_to_agent_name
[
0
]
=
handshake
(
path
,
0
)
# Handshake only with the
other TP remote
th
e
current local rank will
# Handshake only with the
remote TP rank
th
at
current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
metadata
.
tp_size
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
remote_
tp_size
p_remote_rank
=
self
.
tp_rank
//
tp_ratio
p_remote_rank
=
self
.
tp_rank
//
tp_ratio
if
p_remote_rank
>
0
:
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
path
,
p_remote_rank
)
p_remote_rank
)
_
,
rank_to_agent_name
[
p_remote_rank
]
=
handshake
(
# Remote rank -> agent name.
path
,
p_remote_rank
)
return
{
p_remote_rank
:
handshake
(
path
,
p_remote_rank
)}
return
rank_to_agent_name
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
"""Register the KV Cache data in nixl."""
...
@@ -645,7 +640,6 @@ class NixlConnectorWorker:
...
@@ -645,7 +640,6 @@ class NixlConnectorWorker:
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
tp_size
=
self
.
world_size
,
block_len
=
self
.
block_len
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
)
attn_backend_name
=
self
.
backend_name
)
ready_event
=
threading
.
Event
()
ready_event
=
threading
.
Event
()
...
@@ -659,7 +653,8 @@ class NixlConnectorWorker:
...
@@ -659,7 +653,8 @@ class NixlConnectorWorker:
def
add_remote_agent
(
self
,
def
add_remote_agent
(
self
,
nixl_agent_meta
:
NixlAgentMetadata
,
nixl_agent_meta
:
NixlAgentMetadata
,
remote_tp_rank
:
int
=
0
)
->
str
:
remote_tp_rank
:
int
=
0
,
remote_tp_size
:
int
=
1
)
->
str
:
"""
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
blocks from remote.
...
@@ -704,9 +699,9 @@ class NixlConnectorWorker:
...
@@ -704,9 +699,9 @@ class NixlConnectorWorker:
return
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
return
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
if
engine_id
in
self
.
_tp_size
:
if
engine_id
in
self
.
_tp_size
:
assert
self
.
_tp_size
[
engine_id
]
==
nixl_agent_meta
.
tp_size
assert
self
.
_tp_size
[
engine_id
]
==
remote_
tp_size
else
:
else
:
self
.
_tp_size
[
engine_id
]
=
nixl_agent_meta
.
tp_size
self
.
_tp_size
[
engine_id
]
=
remote_
tp_size
# We may eventually enable this after asserting equality in cache
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
# layout and close outputs.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
...
@@ -756,9 +751,7 @@ class NixlConnectorWorker:
...
@@ -756,9 +751,7 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank
=
self
.
tp_rank
//
tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
# Only register the remote's descriptors if current rank pulls from it.
if
p_remote_tp_rank
==
remote_tp_rank
:
self
.
kv_caches_base_addr
[
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
...
@@ -917,7 +910,7 @@ class NixlConnectorWorker:
...
@@ -917,7 +910,7 @@ class NixlConnectorWorker:
if
fut
is
None
:
if
fut
is
None
:
fut
=
self
.
_handshake_initiation_executor
.
submit
(
fut
=
self
.
_handshake_initiation_executor
.
submit
(
self
.
_nixl_handshake
,
meta
.
remote_host
,
self
.
_nixl_handshake
,
meta
.
remote_host
,
meta
.
remote_port
)
meta
.
remote_port
,
meta
.
tp_size
)
self
.
_handshake_futures
[
remote_engine_id
]
=
fut
self
.
_handshake_futures
[
remote_engine_id
]
=
fut
def
done_callback
(
f
:
Future
[
dict
[
int
,
str
]],
def
done_callback
(
f
:
Future
[
dict
[
int
,
str
]],
...
@@ -957,13 +950,9 @@ class NixlConnectorWorker:
...
@@ -957,13 +950,9 @@ class NixlConnectorWorker:
remote_block_ids
=
meta
.
remote_block_ids
,
remote_block_ids
=
meta
.
remote_block_ids
,
)
)
def
_read_blocks
(
def
_read_blocks
(
self
,
local_block_ids
:
list
[
int
],
self
,
remote_block_ids
:
list
[
int
],
dst_engine_id
:
str
,
local_block_ids
:
list
[
int
],
request_id
:
str
):
remote_block_ids
:
list
[
int
],
dst_engine_id
:
str
,
request_id
:
str
,
):
# NOTE(rob): having the staging blocks be on the READER side is
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
# after we detect the txn is complete (which means we cannot make the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment