Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a6dc67e
Unverified
Commit
2a6dc67e
authored
Oct 04, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Oct 04, 2025
Browse files
[Bugfix] Fix `_reqs_to_process` leak on abort (#26012)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
f05fea1f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
82 additions
and
2 deletions
+82
-2
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+66
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+16
-2
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
2a6dc67e
...
...
@@ -33,6 +33,7 @@ from vllm.platforms.interface import Platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
...
@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
assert
mock_dereg
.
call_count
==
2
mock_dereg
.
assert_any_call
(
"desc1"
)
mock_dereg
.
assert_any_call
(
"desc2"
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
    """
    Check that an aborted request stops being tracked by the worker.

    A remote-decode request is scheduled through the real scheduler, and the
    resulting connector metadata is handed to the worker so it lands in the
    worker's `_reqs_to_process`. The request is then aborted; the metadata
    from the following scheduling pass must not list it in `reqs_in_batch`,
    and after the worker consumes that metadata the request must be gone
    from `_reqs_to_process`.
    """
    config = create_vllm_config()
    scheduler = create_scheduler(config)

    # Worker-side KV connector backed by the fake NIXL worker.
    connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        config,
        connector.engine_id,
        hand_shake_latency=0,
    )

    # A do_remote_decode request is expected to end up in reqs_in_batch.
    request = create_request(
        request_id=1,
        do_remote_decode=True,
        max_tokens=1,
    )
    scheduler.add_request(request)

    # First scheduling pass: examine the connector metadata it produced.
    first_output = scheduler.schedule()
    first_meta = first_output.kv_connector_metadata
    assert first_meta is not None
    assert isinstance(first_meta, NixlConnectorMetadata)
    assert request.request_id in first_meta.reqs_in_batch

    # ---- model runner step 1: hand the metadata to the worker ----
    connector.bind_connector_metadata(first_meta)
    forward_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(forward_ctx)

    # The worker must now be tracking the request.
    assert request.request_id in connector.connector_worker._reqs_to_process
    # ---- end of model runner step 1 ----

    # Abort the request through the scheduler.
    scheduler.finish_requests(
        request.request_id,
        RequestStatus.FINISHED_ABORTED,
    )

    # Second scheduling pass: metadata built after the abort.
    second_output = scheduler.schedule()
    second_meta = second_output.kv_connector_metadata
    assert second_meta is not None
    assert isinstance(second_meta, NixlConnectorMetadata)
    assert request.request_id not in second_meta.reqs_in_batch

    # ---- model runner step 2: worker consumes the post-abort metadata ----
    connector.bind_connector_metadata(second_meta)
    connector.start_load_kv(forward_ctx)

    # The aborted request must no longer be tracked as "in-batch".
    assert (
        request.request_id
        not in connector.connector_worker._reqs_to_process
    )
    # ---- end of model runner step 2 ----
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
2a6dc67e
...
...
@@ -113,6 +113,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
.
reqs_to_save
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
reqs_in_batch
:
set
[
ReqId
]
=
set
()
self
.
reqs_not_processed
:
set
[
ReqId
]
=
set
()
def
add_new_req
(
self
,
...
...
@@ -287,6 +288,9 @@ class NixlConnectorScheduler:
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_in_batch
:
set
[
ReqId
]
=
set
()
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
self
.
_reqs_not_processed
:
set
[
ReqId
]
=
set
()
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
...
...
@@ -401,11 +405,13 @@ class NixlConnectorScheduler:
meta
.
reqs_to_send
=
self
.
_reqs_need_send
meta
.
reqs_in_batch
=
self
.
_reqs_in_batch
meta
.
reqs_not_processed
=
self
.
_reqs_not_processed
# Clear the list once workers start the transfers
self
.
_reqs_need_recv
.
clear
()
self
.
_reqs_need_save
.
clear
()
self
.
_reqs_in_batch
=
set
()
self
.
_reqs_not_processed
=
set
()
self
.
_reqs_need_send
=
{}
return
meta
...
...
@@ -439,8 +445,12 @@ class NixlConnectorScheduler:
params
[
"do_remote_prefill"
]
=
False
return
False
,
None
if
(
not
params
.
get
(
"do_remote_decode"
)
or
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
):
if
not
params
.
get
(
"do_remote_decode"
):
return
False
,
None
if
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
:
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self
.
_reqs_not_processed
.
add
(
request
.
request_id
)
return
False
,
None
# TODO: check whether block_ids can actually ever be 0. If not we could
...
...
@@ -1234,6 +1244,10 @@ class NixlConnectorWorker:
for
req_id
in
metadata
.
reqs_in_batch
:
self
.
_reqs_to_process
.
add
(
req_id
)
# Remove all requests that are not to be processed (eg aborted).
for
req_id
in
metadata
.
reqs_not_processed
:
self
.
_reqs_to_process
.
discard
(
req_id
)
# Add to requests that are waiting to be read and track expiration.
for
req_id
,
expiration_time
in
metadata
.
reqs_to_send
.
items
():
if
req_id
in
self
.
_reqs_to_process
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment