Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55c65ab4
Unverified
Commit
55c65ab4
authored
Jun 25, 2025
by
Nick Hill
Committed by
GitHub
Jun 25, 2025
Browse files
[P/D] Avoid stranding blocks in P when aborted in D's waiting queue (#19223)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
2cc20699
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
1 deletion
+14
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+14
-1
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
55c65ab4
...
...
@@ -298,8 +298,21 @@ class NixlConnectorScheduler:
logger
.
debug
(
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s"
,
request
.
status
,
params
)
if
not
params
:
return
False
,
None
if
params
.
get
(
"do_remote_prefill"
):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
self
.
_reqs_need_recv
[
request
.
request_id
]
=
(
request
,
[])
params
[
"do_remote_prefill"
]
=
False
return
False
,
None
if
(
params
is
None
or
not
params
.
get
(
"do_remote_decode"
)
if
(
not
params
.
get
(
"do_remote_decode"
)
or
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
):
return
False
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment