Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53c9a7ce
Unverified
Commit
53c9a7ce
authored
Oct 13, 2025
by
Will Eaton
Committed by
GitHub
Oct 13, 2025
Browse files
[P/D] [NixlConnector] kv load recovery integration (#26171)
Signed-off-by:
Will Eaton
<
weaton@redhat.com
>
parent
0d21b9b5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
252 additions
and
26 deletions
+252
-26
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+142
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+109
-24
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+1
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
53c9a7ce
...
...
@@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
from typing import Optional
{
fake_nixl_source
}
...
...
@@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
# After abort, the worker should not keep tracking it as "in-batch"
assert
req
.
request_id
not
in
connector
.
connector_worker
.
_reqs_to_process
#### Model Runner end ####
class
FailingNixlWrapper
(
FakeNixlWrapper
):
"""Mock NixlWrapper that fails on specific operations."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
fail_handshake
=
False
self
.
fail_transfer_setup
=
False
self
.
fail_send_notif
=
False
def
add_remote_agent
(
self
,
agent_metadata
:
bytes
)
->
str
:
if
self
.
fail_handshake
:
from
zmq.error
import
Again
raise
Again
(
"Simulated timeout failure"
)
return
super
().
add_remote_agent
(
agent_metadata
)
def
make_prepped_xfer
(
self
,
xfer_type
:
str
,
local_xfer_side_handle
:
int
,
local_block_descs_ids
:
list
[
int
],
remote_xfer_side_handle
:
int
,
remote_block_descs_ids
:
list
[
int
],
notif_msg
:
bytes
|
None
=
None
,
)
->
int
:
if
self
.
fail_transfer_setup
:
# classic RuntimeError to simulate failure
raise
RuntimeError
(
"BAD STATUS"
)
return
super
().
make_prepped_xfer
(
xfer_type
,
local_xfer_side_handle
,
local_block_descs_ids
,
remote_xfer_side_handle
,
remote_block_descs_ids
,
notif_msg
,
)
def
send_notif
(
self
,
agent_name
:
str
,
notif_msg
:
bytes
)
->
None
:
if
self
.
fail_send_notif
:
raise
RuntimeError
(
"Simulated send_notif failure"
)
return
super
().
send_notif
(
agent_name
,
notif_msg
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FailingNixlWrapper
,
)
def
test_handshake_failure_returns_finished
(
dist_init
):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0.1
)
connector
.
connector_worker
.
nixl_wrapper
.
fail_handshake
=
True
request_id
=
"test_handshake_fail"
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
},
)
connector
.
bind_connector_metadata
(
metadata
)
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
connector
.
start_load_kv
(
dummy_ctx
)
# Wait for handshake to fail
time
.
sleep
(
0.3
)
# Check that blocks were marked invalid
invalid_blocks
=
connector
.
get_block_ids_with_load_errors
()
assert
invalid_blocks
==
{
1
,
2
,
3
}
# Check that request appears in get_finished
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
assert
request_id
in
done_recving
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FailingNixlWrapper
,
)
def
test_transfer_setup_failure_returns_finished
(
dist_init
):
"""Test that transfer setup failures mark blocks invalid
and return via get_finished."""
vllm_config
=
create_vllm_config
()
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
connector
.
connector_worker
.
nixl_wrapper
.
fail_transfer_setup
=
True
request_id
=
"test_transfer_fail"
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req
(
request_id
=
request_id
,
local_block_ids
=
[
7
,
8
,
9
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
},
)
connector
.
bind_connector_metadata
(
metadata
)
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
connector
.
start_load_kv
(
dummy_ctx
)
# Wait for handshake to complete and process ready_requests
connector
.
bind_connector_metadata
(
NixlConnectorMetadata
())
time
.
sleep
(
0.1
)
connector
.
start_load_kv
(
dummy_ctx
)
# check that blocks were marked invalid
invalid_blocks
=
connector
.
get_block_ids_with_load_errors
()
assert
invalid_blocks
==
{
7
,
8
,
9
}
# ensure request appears in get_finished
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
assert
request_id
in
done_recving
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
53c9a7ce
...
...
@@ -68,6 +68,7 @@ except ImportError:
NixlWrapper
=
None
nixlXferTelemetry
=
None
try
:
from
nixl._api
import
nixl_agent_config
except
ImportError
:
...
...
@@ -234,6 +235,11 @@ class NixlConnector(KVConnectorBase_V1):
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_finished
()
def
get_block_ids_with_load_errors
(
self
)
->
set
[
int
]:
"""Get block IDs that failed to load via NIXL."""
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_block_ids_with_load_errors
()
def
get_kv_connector_stats
(
self
)
->
KVConnectorStats
|
None
:
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_kv_connector_stats
()
...
...
@@ -614,6 +620,11 @@ class NixlConnectorWorker:
# Set of requests that have been part of a batch, regardless of status.
self
.
_reqs_to_process
:
set
[
ReqId
]
=
set
()
# invalid blocks from failed NIXL operations
self
.
_invalid_block_ids
:
set
[
int
]
=
set
()
# requests that skipped transfer (handshake or transfer failures)
self
.
_failed_recv_reqs
:
set
[
ReqId
]
=
set
()
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
# Background thread for initializing new NIXL handshakes.
...
...
@@ -713,6 +724,8 @@ class NixlConnectorWorker:
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
GET_META_MSG
)
metadata_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
...
...
@@ -795,10 +808,20 @@ class NixlConnectorWorker:
fut
.
add_done_callback
(
done_callback
)
# TODO: handle failure state of future in the
# callback, we want to fail the request in this case.
def
request_ready
(
_f
:
Future
[
Any
],
entry
=
(
req_id
,
meta
)):
# check handshake success before proceeding with request
def
request_ready
(
f
:
Future
[
Any
],
entry
=
(
req_id
,
meta
)):
try
:
# check if handshake succeeded
f
.
result
()
self
.
_ready_requests
.
put
(
entry
)
except
Exception
:
# handshake failed - mark blocks as invalid
logger
.
exception
(
"Handshake failed for request %s, marking blocks as invalid"
,
req_id
)
if
req_meta
:
=
self
.
_recving_metadata
.
get
(
req_id
):
self
.
_invalid_block_ids
.
update
(
req_meta
.
local_block_ids
)
self
.
_failed_recv_reqs
.
add
(
req_id
)
fut
.
add_done_callback
(
request_ready
)
...
...
@@ -1205,6 +1228,11 @@ class NixlConnectorWorker:
"""
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
# add requests that skipped transfer to done_recving
done_recving
.
update
(
self
.
_failed_recv_reqs
)
self
.
_failed_recv_reqs
.
clear
()
if
len
(
done_sending
)
>
0
or
len
(
done_recving
)
>
0
:
logger
.
debug
(
"Rank %s, get_finished: %s requests done sending "
...
...
@@ -1214,10 +1242,10 @@ class NixlConnectorWorker:
len
(
done_recving
),
)
if
self
.
use_host_buffer
:
# clean up metadata for completed requests
for
req_id
in
done_recving
:
meta
=
self
.
_recving_metadata
.
pop
(
req_id
)
assert
meta
,
f
"
{
req_id
}
not found in recving_metadata list"
meta
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
if
self
.
use_host_buffer
and
meta
:
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
# Handle timeout to avoid stranding blocks on remote.
...
...
@@ -1296,7 +1324,19 @@ class NixlConnectorWorker:
in_progress
=
True
continue
else
:
raise
RuntimeError
(
"Transfer failed with state %s"
,
xfer_state
)
# transfer failed - mark blocks as invalid
logger
.
error
(
"NIXL transfer failed for request %s with state %s. "
"Marking blocks as invalid."
,
req_id
,
xfer_state
,
)
# mark all blocks for this request as invalid
if
meta
:
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
xfer_stats
.
record_failed_transfer
()
if
not
in_progress
:
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
...
...
@@ -1317,7 +1357,7 @@ class NixlConnectorWorker:
len
(
meta
.
local_block_ids
),
len
(
meta
.
remote_block_ids
),
)
if
self
.
use_host_buffer
:
# always store metadata for failure recovery
self
.
_recving_metadata
[
req_id
]
=
meta
if
remote_engine_id
not
in
self
.
_remote_agents
:
# Initiate handshake with remote engine to exchange metadata.
...
...
@@ -1394,7 +1434,16 @@ class NixlConnectorWorker:
if
num_local_blocks
==
0
:
remote_rank
=
self
.
tp_rank
//
tp_ratio
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
except
Exception
:
logger
.
exception
(
"NIXL send_notif failed for request %s: "
"P worker blocks will be freed after timeout. "
"This may indicate network issues."
,
request_id
,
)
self
.
xfer_stats
.
record_failed_notification
()
return
# Partial prefix cache hit: just read uncomputed blocks.
...
...
@@ -1456,6 +1505,8 @@ class NixlConnectorWorker:
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
# Prepare transfer with Nixl.
handle
=
None
try
:
handle
=
self
.
nixl_wrapper
.
make_prepped_xfer
(
"READ"
,
local_xfer_side_handle
,
...
...
@@ -1470,6 +1521,19 @@ class NixlConnectorWorker:
# Use handle to check completion in future step().
self
.
_recving_transfers
[
request_id
].
append
((
handle
,
time
.
perf_counter
()))
except
Exception
:
logger
.
exception
(
"NIXL transfer setup/initiation failed for request %s. "
"Marking blocks as invalid."
,
request_id
,
)
# mark all blocks for this request as invalid
if
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
xfer_stats
.
record_failed_transfer
()
if
handle
is
not
None
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_failed_recv_reqs
.
add
(
request_id
)
def
_get_block_descs_ids
(
self
,
engine_id
:
str
,
block_ids
:
list
[
int
],
layer_idx
:
int
|
None
=
None
...
...
@@ -1527,6 +1591,17 @@ class NixlConnectorWorker:
return
self
.
xfer_stats
.
clone_and_reset
()
return
None
def
get_block_ids_with_load_errors
(
self
)
->
set
[
int
]:
"""
Return and clear the set of block IDs that failed to load.
This is called by the scheduler to identify blocks that need
to be retried after a NIXL transfer failure.
"""
result
=
self
.
_invalid_block_ids
self
.
_invalid_block_ids
=
set
()
return
result
def
shutdown
(
self
):
"""Shutdown the connector worker."""
self
.
_handshake_initiation_executor
.
shutdown
(
wait
=
False
)
...
...
@@ -1586,6 +1661,8 @@ class NixlKVConnectorStats(KVConnectorStats):
"post_duration"
:
[],
"bytes_transferred"
:
[],
"num_descriptors"
:
[],
"num_failed_transfers"
:
[],
"num_failed_notifications"
:
[],
}
def
record_transfer
(
self
,
res
:
nixlXferTelemetry
):
...
...
@@ -1595,6 +1672,14 @@ class NixlKVConnectorStats(KVConnectorStats):
self
.
data
[
"bytes_transferred"
].
append
(
res
.
totalBytes
)
self
.
data
[
"num_descriptors"
].
append
(
res
.
descCount
)
def
record_failed_transfer
(
self
):
"""Record a failed NIXL transfer operation."""
self
.
data
[
"num_failed_transfers"
].
append
(
1.0
)
def
record_failed_notification
(
self
):
"""Record a failed NIXL notification (send_notif)."""
self
.
data
[
"num_failed_notifications"
].
append
(
1.0
)
def
clone_and_reset
(
self
)
->
"NixlKVConnectorStats"
:
old
=
copy
.
copy
(
self
)
self
.
reset
()
...
...
vllm/v1/core/sched/scheduler.py
View file @
53c9a7ce
...
...
@@ -1487,7 +1487,7 @@ class Scheduler(SchedulerInterface):
total_tokens_to_reschedule
+=
num_tokens_to_reschedule
# Mark requests with async KV load failures; they will be rescheduled
# once loading completes
# once loading completes
.
self
.
failed_recving_kv_req_ids
|=
async_affected_req_ids
# --- Handle sync KV loads (running requests) ---
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment