Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ce5c5d3
Unverified
Commit 2ce5c5d3 authored Oct 29, 2025 by Nick Hill
Committed by GitHub Oct 29, 2025
Browse files
[BugFix] Handle unscheduled requests properly when async scheduling (#27756)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent
b5bae42f
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing 9 changed files with 63 additions and 43 deletions
+63
-43
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+3
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+3
-3
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+1
-1
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+24
-8
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+20
-19
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-7
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-1
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
2ce5c5d3
...
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
...
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
# resume req
# resume req
cached_req_data
=
CachedRequestData
(
cached_req_data
=
CachedRequestData
(
req_ids
=
[
req_id
],
req_ids
=
[
req_id
],
resumed_
from_preemption
=
[
False
]
,
resumed_
req_ids
=
{
req_id
}
,
new_token_ids
=
[[]],
new_token_ids
=
[[]],
all_token_ids
=
{
req_id
:
scheduler_output
.
scheduled_new_reqs
[
0
].
prompt_token_ids
},
new_block_ids
=
[([],)],
new_block_ids
=
[([],)],
num_computed_tokens
=
[
0
],
num_computed_tokens
=
[
0
],
num_output_tokens
=
[
0
],
)
)
scheduler_output
=
SchedulerOutput
(
scheduler_output
=
SchedulerOutput
(
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
2ce5c5d3
...
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
...
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
# resume req
# resume req
cached_req_data
=
CachedRequestData
(
cached_req_data
=
CachedRequestData
(
req_ids
=
[
req_id
],
req_ids
=
[
req_id
],
resumed_
from_preemption
=
[
False
]
,
resumed_
req_ids
=
set
()
,
new_token_ids
=
[[]],
new_token_ids
=
[[]],
resumed_req
_token_ids
=
[
None
]
,
all
_token_ids
=
{}
,
new_block_ids
=
(
[[
0
]
]
,),
new_block_ids
=
[
(
[
0
],)
]
,
num_computed_tokens
=
[
0
],
num_computed_tokens
=
[
0
],
num_output_tokens
=
[
0
],
num_output_tokens
=
[
0
],
)
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
2ce5c5d3
...
@@ -494,5 +494,5 @@ def yield_req_data(
...
@@ -494,5 +494,5 @@ def yield_req_data(
yield
from
zip
(
yield
from
zip
(
cached_reqs
.
req_ids
,
cached_reqs
.
req_ids
,
cached_reqs
.
new_block_ids
,
cached_reqs
.
new_block_ids
,
cached_reqs
.
resumed_
from_preemption
,
(
req_id
in
cached_reqs
.
resumed_
req_ids
for
req_id
in
cached_reqs
.
req_ids
)
,
)
)
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
2ce5c5d3
...
@@ -415,10 +415,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -415,10 +415,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
num_computed_tokens
=
cached_reqs
.
num_computed_tokens
[
i
]
num_computed_tokens
=
cached_reqs
.
num_computed_tokens
[
i
]
new_block_ids
=
cached_reqs
.
new_block_ids
[
i
]
new_block_ids
=
cached_reqs
.
new_block_ids
[
i
]
resumed_from_preemption
=
cached_reqs
.
resumed_
from_preemption
[
i
]
resumed_from_preemption
=
req_id
in
cached_reqs
.
resumed_
req_ids
if
self
.
is_producer
:
if
self
.
is_producer
:
num_scheduled_tokens
=
(
scheduler_output
.
num_scheduled_tokens
)
[
req_id
]
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_tokens
=
num_scheduled_tokens
+
num_computed_tokens
num_tokens
=
num_scheduled_tokens
+
num_computed_tokens
assert
req_id
in
self
.
chunked_prefill
assert
req_id
in
self
.
chunked_prefill
assert
new_block_ids
is
not
None
assert
new_block_ids
is
not
None
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
2ce5c5d3
...
@@ -336,7 +336,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -336,7 +336,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
resumed_from_preemption
=
cached_reqs
.
resumed_
from_preemption
[
i
]
resumed_from_preemption
=
req_id
in
cached_reqs
.
resumed_
req_ids
if
not
resumed_from_preemption
or
req_id
not
in
self
.
_requests_need_load
:
if
not
resumed_from_preemption
or
req_id
not
in
self
.
_requests_need_load
:
continue
continue
...
...
vllm/v1/core/sched/output.py
View file @
2ce5c5d3
...
@@ -2,8 +2,11 @@
...
@@ -2,8 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
typing_extensions
import
deprecated
from
vllm._bc_linter
import
bc_linter_include
from
vllm._bc_linter
import
bc_linter_include
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -96,16 +99,16 @@ class NewRequestData:
...
@@ -96,16 +99,16 @@ class NewRequestData:
@
dataclass
@
dataclass
class
CachedRequestData
:
class
CachedRequestData
:
req_ids
:
list
[
str
]
req_ids
:
list
[
str
]
# If resumed_from_preemption is False, new_block_ids will be appended to
# For request ids not in resumed_req_ids, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# the request's block IDs. For those in the set, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
# request's block IDs instead of appending to the existing block IDs.
resumed_
from_preemption
:
list
[
bool
]
resumed_
req_ids
:
set
[
str
]
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
# When PP is not used, new_token_ids will be empty.
new_token_ids
:
list
[
list
[
int
]]
new_token_ids
:
list
[
list
[
int
]]
# If resumed_from_preemption is True, propogate the token ids to the
# For requests not scheduled in the last step, propagate the token ids to the
# connector, otherwise will be empty.
# connector. Won't contain requests that were scheduled in the prior step.
resumed_req
_token_ids
:
list
[
list
[
int
]
|
None
]
all
_token_ids
:
dict
[
str
,
list
[
int
]]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
num_computed_tokens
:
list
[
int
]
num_computed_tokens
:
list
[
int
]
num_output_tokens
:
list
[
int
]
num_output_tokens
:
list
[
int
]
...
@@ -114,13 +117,26 @@ class CachedRequestData:
...
@@ -114,13 +117,26 @@ class CachedRequestData:
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_ids
)
return
len
(
self
.
req_ids
)
@
cached_property
@
deprecated
(
"use resumed_req_ids field"
)
def
resumed_from_preemption
(
self
)
->
list
[
bool
]:
return
[
req_id
in
self
.
resumed_req_ids
for
req_id
in
self
.
req_ids
]
@
cached_property
@
deprecated
(
"use all_token_ids field"
)
def
resumed_req_token_ids
(
self
)
->
list
[
list
[
int
]
|
None
]:
return
[
self
.
all_token_ids
[
req_id
]
if
req_id
in
self
.
resumed_req_ids
else
None
for
req_id
in
self
.
req_ids
]
@
classmethod
@
classmethod
def
make_empty
(
cls
)
->
"CachedRequestData"
:
def
make_empty
(
cls
)
->
"CachedRequestData"
:
return
cls
(
return
cls
(
req_ids
=
[],
req_ids
=
[],
resumed_
from_preemption
=
[]
,
resumed_
req_ids
=
set
()
,
new_token_ids
=
[],
new_token_ids
=
[],
resumed_req
_token_ids
=
[]
,
all
_token_ids
=
{}
,
new_block_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
[],
num_computed_tokens
=
[],
num_output_tokens
=
[],
num_output_tokens
=
[],
...
...
vllm/v1/core/sched/scheduler.py
View file @
2ce5c5d3
...
@@ -71,6 +71,7 @@ class Scheduler(SchedulerInterface):
...
@@ -71,6 +71,7 @@ class Scheduler(SchedulerInterface):
self
.
finished_req_ids_dict
:
dict
[
int
,
set
[
str
]]
|
None
=
(
self
.
finished_req_ids_dict
:
dict
[
int
,
set
[
str
]]
|
None
=
(
defaultdict
(
set
)
if
include_finished_set
else
None
defaultdict
(
set
)
if
include_finished_set
else
None
)
)
self
.
prev_step_scheduled_req_ids
:
set
[
str
]
=
set
()
# Scheduling constraints.
# Scheduling constraints.
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
...
@@ -444,14 +445,9 @@ class Scheduler(SchedulerInterface):
...
@@ -444,14 +445,9 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
# requests, which have output tokens.
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
if
(
threshold
=
self
.
scheduler_config
.
long_prefill_token_threshold
0
if
0
<
threshold
<
num_new_tokens
:
<
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
threshold
<
num_new_tokens
):
num_new_tokens
=
(
self
.
scheduler_config
.
long_prefill_token_threshold
)
# chunked prefill has to be enabled explicitly to allow
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
# pooling requests to be chunked
...
@@ -620,6 +616,11 @@ class Scheduler(SchedulerInterface):
...
@@ -620,6 +616,11 @@ class Scheduler(SchedulerInterface):
structured_output_request_ids
,
grammar_bitmask
=
self
.
get_grammar_bitmask
(
structured_output_request_ids
,
grammar_bitmask
=
self
.
get_grammar_bitmask
(
num_scheduled_tokens
.
keys
(),
scheduled_spec_decode_tokens
num_scheduled_tokens
.
keys
(),
scheduled_spec_decode_tokens
)
)
# Record the request ids that were scheduled in this step.
self
.
prev_step_scheduled_req_ids
.
clear
()
self
.
prev_step_scheduled_req_ids
.
update
(
num_scheduled_tokens
.
keys
())
scheduler_output
=
SchedulerOutput
(
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs_data
,
scheduled_new_reqs
=
new_reqs_data
,
scheduled_cached_reqs
=
cached_reqs_data
,
scheduled_cached_reqs
=
cached_reqs_data
,
...
@@ -691,14 +692,12 @@ class Scheduler(SchedulerInterface):
...
@@ -691,14 +692,12 @@ class Scheduler(SchedulerInterface):
req_ids
:
list
[
str
]
=
[]
req_ids
:
list
[
str
]
=
[]
new_token_ids
:
list
[
list
[
int
]]
=
[]
new_token_ids
:
list
[
list
[
int
]]
=
[]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
=
[]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
=
[]
resumed_req
_token_ids
:
list
[
list
[
int
]
|
None
]
=
[]
all
_token_ids
:
dict
[
str
,
list
[
int
]]
=
{}
num_computed_tokens
:
list
[
int
]
=
[]
num_computed_tokens
:
list
[
int
]
=
[]
num_output_tokens
:
list
[
int
]
=
[]
num_output_tokens
:
list
[
int
]
=
[]
resumed_req_ids
=
set
()
# Because resumed_reqs is usually empty, it is more efficient to do
num_running_reqs
=
len
(
running_reqs
)
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption
=
[
False
]
*
len
(
running_reqs
)
resumed_from_preemption
+=
[
True
]
*
len
(
resumed_reqs
)
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
req_id
=
req
.
request_id
req_id
=
req
.
request_id
req_ids
.
append
(
req_id
)
req_ids
.
append
(
req_id
)
...
@@ -715,12 +714,14 @@ class Scheduler(SchedulerInterface):
...
@@ -715,12 +714,14 @@ class Scheduler(SchedulerInterface):
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
]
]
new_token_ids
.
append
(
token_ids
)
new_token_ids
.
append
(
token_ids
)
resumed_token_ids
=
None
scheduled_in_prev_step
=
req_id
in
self
.
prev_step_scheduled_req_ids
if
resumed_from_preemption
[
idx
]:
if
idx
>=
num_running_reqs
:
resumed_token_ids
=
req
.
all_token_ids
[
assert
not
scheduled_in_prev_step
resumed_req_ids
.
add
(
req_id
)
if
not
scheduled_in_prev_step
:
all_token_ids
[
req_id
]
=
req
.
all_token_ids
[
:
req
.
num_computed_tokens
+
num_tokens
:
req
.
num_computed_tokens
+
num_tokens
]
]
resumed_req_token_ids
.
append
(
resumed_token_ids
)
new_block_ids
.
append
(
new_block_ids
.
append
(
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
)
)
...
@@ -731,9 +732,9 @@ class Scheduler(SchedulerInterface):
...
@@ -731,9 +732,9 @@ class Scheduler(SchedulerInterface):
return
CachedRequestData
(
return
CachedRequestData
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
resumed_
from_preemption
=
resumed_from_preemption
,
resumed_
req_ids
=
resumed_req_ids
,
new_token_ids
=
new_token_ids
,
new_token_ids
=
new_token_ids
,
resumed_req_token_ids
=
resumed_req
_token_ids
,
all_token_ids
=
all
_token_ids
,
new_block_ids
=
new_block_ids
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
num_computed_tokens
=
num_computed_tokens
,
num_output_tokens
=
num_output_tokens
,
num_output_tokens
=
num_output_tokens
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
2ce5c5d3
...
@@ -706,7 +706,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -706,7 +706,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
resumed_from_preemption
=
req_data
.
resumed_
from_preemption
[
i
]
resumed_from_preemption
=
req_id
in
req_data
.
resumed_
req_ids
num_output_tokens
=
req_data
.
num_output_tokens
[
i
]
num_output_tokens
=
req_data
.
num_output_tokens
[
i
]
# Update the cached states.
# Update the cached states.
...
@@ -754,16 +754,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -754,16 +754,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Replace the existing block IDs with the new ones.
# Replace the existing block IDs with the new ones.
req_state
.
block_ids
=
new_block_ids
req_state
.
block_ids
=
new_block_ids
if
self
.
use_async_scheduling
and
num_output_tokens
>
0
:
# We must recover the output token ids for resumed requests in the
# async scheduling case, so that correct input_ids are obtained.
resumed_token_ids
=
req_data
.
resumed_req_token_ids
[
i
]
assert
resumed_token_ids
is
not
None
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
if
req_index
is
None
:
if
req_index
is
None
:
# The request is not in the persistent batch.
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
# scheduled in the previous step and needs to be added again.
if
self
.
use_async_scheduling
and
num_output_tokens
>
0
:
# We must recover the output token ids for resumed requests in the
# async scheduling case, so that correct input_ids are obtained.
resumed_token_ids
=
req_data
.
all_token_ids
[
req_id
]
req_state
.
output_token_ids
=
resumed_token_ids
[
-
num_output_tokens
:]
reqs_to_add
.
append
(
req_state
)
reqs_to_add
.
append
(
req_state
)
continue
continue
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
2ce5c5d3
...
@@ -483,7 +483,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -483,7 +483,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
resumed_from_preemption
=
req_data
.
resumed_
from_preemption
[
i
]
resumed_from_preemption
=
req_id
in
req_data
.
resumed_
req_ids
# Update the cached states.
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
req_state
.
num_computed_tokens
=
num_computed_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment