Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9f61e13
Unverified
Commit
b9f61e13
authored
Jun 01, 2025
by
Robert Shaw
Committed by
GitHub
Jun 02, 2025
Browse files
[Bugfix][Nixl] Fix DP Metadata Handshake (#19008)
Signed-off-by:
rshaw@neuralmagic.com
<
robertgshaw2@gmail.com
>
parent
d6fd3a33
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
32 deletions
+36
-32
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+36
-32
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
b9f61e13
...
...
@@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
,
get_world_group
)
get_tp_group
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
,
round_down
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -172,6 +172,11 @@ class NixlConnectorScheduler:
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
engine_id
=
engine_id
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_port
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
vllm_config
.
parallel_config
.
data_parallel_rank_local
*
vllm_config
.
parallel_config
.
tensor_parallel_size
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
# Requests that need to start recv.
...
...
@@ -310,8 +315,8 @@ class NixlConnectorScheduler:
do_remote_decode
=
False
,
remote_block_ids
=
computed_block_ids
,
remote_engine_id
=
self
.
engine_id
,
remote_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
,
remote_port
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
,
remote_host
=
self
.
side_channel_host
,
remote_port
=
self
.
side_channel_port
,
)
...
...
@@ -330,11 +335,19 @@ class NixlConnectorWorker:
# Map of engine_id -> agent_name.
self
.
_remote_agents
:
dict
[
str
,
str
]
=
{}
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
self
.
side_channel_port
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
vllm_config
.
parallel_config
.
data_parallel_rank_local
*
vllm_config
.
parallel_config
.
tensor_parallel_size
)
# Metadata.
self
.
engine_id
=
engine_id
self
.
rank
=
get_tensor_model_parallel_rank
()
self
.
tp_
rank
=
get_tensor_model_parallel_rank
()
self
.
world_size
=
get_tensor_model_parallel_world_size
()
self
.
world_rank
=
get_world_group
().
rank_in_group
self
.
tp_group
=
get_tp_group
()
# KV Caches and nixl tracking data.
...
...
@@ -383,16 +396,11 @@ class NixlConnectorWorker:
@
staticmethod
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
ready_event
:
threading
.
Event
,
world
_rank
:
int
):
ready_event
:
threading
.
Event
,
base_port
:
int
,
tp
_rank
:
int
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach like an ETCD server in the future.
# NOTE(rob): to support heterogeneous TP, we will have to
# move this into the scheduler rather than worker, since
# each rank needs the metadata of all other ranks (whereas
# in this setup, each rank only gets one other rank's meta.
# to a better approach via HTTP endpoint soon.
encoder
=
msgspec
.
msgpack
.
Encoder
()
encoded_data
=
encoder
.
encode
(
metadata
)
...
...
@@ -402,11 +410,7 @@ class NixlConnectorWorker:
# Listen for new requests for metadata.
host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
# NOTE(rob): we need each rank to have a unique port. This
# hack to keeps us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
port
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
world_rank
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
path
=
make_zmq_path
(
"tcp"
,
host
,
base_port
+
tp_rank
)
logger
.
debug
(
"Starting listening on path: %s"
,
path
)
with
zmq_ctx
(
zmq
.
ROUTER
,
path
)
as
sock
:
ready_event
.
set
()
...
...
@@ -421,10 +425,10 @@ class NixlConnectorWorker:
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
# NOTE(rob): we need each rank to have a unique port.
This is
# a hack to keep us moving. We will switch when
moving to etcd
#
or where we have a single ZMQ socket in the scheduler
.
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
self
.
world
_rank
)
# NOTE(rob): we need each
tp_
rank to have a unique port.
#
This is
a hack to keep us moving. We will switch when
#
we switch to HTTP-based NIXL metadata exchange
.
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
self
.
tp
_rank
)
logger
.
debug
(
"Querying metadata on path: %s"
,
path
)
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
# Send query for the request.
...
...
@@ -532,7 +536,7 @@ class NixlConnectorWorker:
ready_event
=
threading
.
Event
()
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
target
=
self
.
_nixl_handshake_listener
,
args
=
(
metadata
,
ready_event
,
self
.
world
_rank
),
args
=
(
metadata
,
ready_event
,
self
.
side_channel_port
,
self
.
tp
_rank
),
daemon
=
True
,
name
=
"nixl_handshake_listener"
)
self
.
_nixl_handshake_listener_t
.
start
()
...
...
@@ -556,9 +560,9 @@ class NixlConnectorWorker:
block_offset
=
block_id
*
self
.
block_len
# (addr, len, device id)
blocks_data
.
append
(
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
rank
)
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
tp_
rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and
tp_
rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_
rank
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
...
...
@@ -573,9 +577,9 @@ class NixlConnectorWorker:
block_offset
=
block_id
*
self
.
block_len
# (addr, len, device id)
blocks_data
.
append
(
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s and rank %s"
,
len
(
blocks_data
),
engine_id
,
self
.
rank
)
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
tp_
rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s and
tp_
rank %s"
,
len
(
blocks_data
),
engine_id
,
self
.
tp_
rank
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
...
...
@@ -600,14 +604,14 @@ class NixlConnectorWorker:
if
len
(
done_sending
)
>
0
or
len
(
done_recving
)
>
0
:
logger
.
debug
(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving"
,
self
.
rank
,
len
(
done_sending
),
len
(
done_recving
))
"and %s requests done recving"
,
self
.
tp_
rank
,
len
(
done_sending
),
len
(
done_recving
))
if
self
.
world_size
==
1
:
return
done_sending
,
done_recving
# Rank 0: get finished from all other ranks.
if
self
.
rank
==
0
:
if
self
.
tp_
rank
==
0
:
for
req_id
in
done_sending
:
self
.
_done_sending_count
[
req_id
]
+=
1
for
req_id
in
done_recving
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment