Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58ce8d12
Unverified
Commit
58ce8d12
authored
Nov 12, 2025
by
Andy Lo
Committed by
GitHub
Nov 12, 2025
Browse files
[BugFix] Priority scheduling and spec tokens preemption (#28558)
Signed-off-by:
Andy Lo
<
andy@mistral.ai
>
parent
94a9ebcf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
266 additions
and
0 deletions
+266
-0
tests/v1/core/test_priority_scheduler_random.py
tests/v1/core/test_priority_scheduler_random.py
+252
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+14
-0
No files found.
tests/v1/core/test_priority_scheduler_random.py
0 → 100644
View file @
58ce8d12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
uuid
import
pytest
from
vllm.config
import
VllmConfig
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
,
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.hashing
import
get_hash_fn_by_name
from
vllm.v1.core.kv_cache_utils
import
get_request_block_hasher
,
init_none_hash
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.request
import
Request
from
.test_scheduler
import
create_scheduler_with_priority
from
.utils
import
EOS_TOKEN_ID
pytestmark
=
pytest
.
mark
.
cpu_test
def _create_random_request(
    max_tokens_range: tuple[int, int],
    num_tokens_range: tuple[int, int],
    arrival_time_range: tuple[float, float],
    priority_range: tuple[int, int],
    num_mm_item_range: tuple[int, int],
    vllm_config: VllmConfig,
):
    """Build a ``Request`` whose shape is drawn from the given ranges.

    Each ``*_range`` argument is an inclusive ``(low, high)`` pair that is
    unpacked into ``random.randint`` (``random.uniform`` for the arrival
    time). The prompt is ``num_tokens`` random token ids in ``[0, 100)``.

    Multimodal placeholders all have a fixed length of 10 tokens. Their start
    offsets are sampled without replacement from the prompt positions and any
    placeholder that would run past the end of the prompt is dropped. No
    spacing is enforced between start offsets, so placeholders may overlap.

    The block hasher is configured from ``vllm_config.cache_config``.
    """
    # The draws below happen in a fixed order so that a seeded ``random``
    # module yields a reproducible stream of requests.
    max_tokens = random.randint(*max_tokens_range)
    num_tokens = random.randint(*num_tokens_range)
    priority = random.randint(*priority_range)
    arrival_time = random.uniform(*arrival_time_range)
    num_mm_item = random.randint(*num_mm_item_range)

    placeholder_length = 10
    sample_size = min(num_mm_item, num_tokens)
    candidate_starts = sorted(random.sample(range(num_tokens), sample_size))
    mm_positions: list[PlaceholderRange] = [
        PlaceholderRange(offset=start, length=placeholder_length)
        for start in candidate_starts
        # Keep only placeholders that fit entirely inside the prompt.
        if start + placeholder_length <= num_tokens
    ]

    request_id = uuid.uuid4().hex
    sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens)

    # One feature per surviving placeholder; the identifier is unique per
    # request because it embeds the request id.
    mm_features = [
        MultiModalFeatureSpec(
            data=MultiModalKwargsItem.dummy("dummy_m"),
            mm_position=placeholder,
            identifier=f"{request_id}_hash_{index}",
            modality="image",
        )
        for index, placeholder in enumerate(mm_positions)
    ]

    prompt_token_ids = random.choices(range(100), k=num_tokens)

    cache_config = vllm_config.cache_config
    caching_hash_fn = get_hash_fn_by_name(cache_config.prefix_caching_hash_algo)
    init_none_hash(caching_hash_fn)
    block_hasher = get_request_block_hasher(cache_config.block_size, caching_hash_fn)

    return Request(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        pooling_params=None,
        # An empty feature list is passed as ``None``.
        mm_features=mm_features or None,
        eos_token_id=EOS_TOKEN_ID,
        arrival_time=arrival_time,
        priority=priority,
        block_hasher=block_hasher,
    )
def _mock_execute_model(
    scheduler_output: SchedulerOutput, num_output_tokens_range: tuple[int, int]
) -> ModelRunnerOutput:
    """Fabricate a ``ModelRunnerOutput`` for every request in this step.

    The scheduled request ids (new requests first, then cached ones) are
    shuffled, and each request receives a random number of sampled tokens
    drawn from the inclusive ``num_output_tokens_range``. Token values are
    random integers in ``[0, 100]``.
    """
    request_ids: list[str] = [
        *(new_req.req_id for new_req in scheduler_output.scheduled_new_reqs),
        *scheduler_output.scheduled_cached_reqs.req_ids,
    ]
    random.shuffle(request_ids)

    # Draw every per-request length first, then the token values, so the
    # sequence of ``random`` calls is fixed for a given seed.
    tokens_per_request = [
        random.randint(*num_output_tokens_range) for _ in request_ids
    ]
    sampled_token_ids = [
        [random.randint(0, 100) for _ in range(count)]
        for count in tokens_per_request
    ]

    # Map each request id to its row in ``sampled_token_ids``.
    req_id_to_index = dict(zip(request_ids, range(len(request_ids))))

    return ModelRunnerOutput(
        req_ids=request_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
def _mock_draft_token_ids(
    scheduler_output: SchedulerOutput,
    num_output_tokens_range: tuple[int, int],
    seen_request_prompt_length: dict[str, int],
) -> DraftTokenIds:
    """Fabricate random draft tokens for the requests scheduled in this step.

    ``seen_request_prompt_length`` is updated in place: every newly scheduled
    request is recorded with its prompt length (0 when it has no prompt token
    ids), and must not have been recorded before.

    A request only receives draft tokens when its ``num_computed_tokens`` has
    reached its recorded prompt length. The number of drafts is drawn from the
    inclusive ``num_output_tokens_range`` and each token is a random integer
    in ``[0, 100]``.
    """
    drafted_req_ids: list[str] = []
    drafted_tokens: list[list[int]] = []

    def _maybe_draft(req_id: str, computed: int) -> None:
        # Skip requests whose computed tokens have not yet covered the prompt.
        if computed < seen_request_prompt_length[req_id]:
            return
        draft_len = random.randint(*num_output_tokens_range)
        drafted_req_ids.append(req_id)
        drafted_tokens.append([random.randint(0, 100) for _ in range(draft_len)])

    # New requests are registered first, then considered for drafting.
    for new_req in scheduler_output.scheduled_new_reqs:
        assert new_req.req_id not in seen_request_prompt_length
        seen_request_prompt_length[new_req.req_id] = len(
            new_req.prompt_token_ids or []
        )
        _maybe_draft(new_req.req_id, new_req.num_computed_tokens)

    # Cached requests must already be registered from an earlier step.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for cached_id, cached_computed in zip(
        cached_reqs.req_ids, cached_reqs.num_computed_tokens
    ):
        _maybe_draft(cached_id, cached_computed)

    return DraftTokenIds(req_ids=drafted_req_ids, draft_token_ids=drafted_tokens)
def
_chech_valid_scheduler_output
(
scheduler_output
:
SchedulerOutput
,
seen_request_ids
:
set
[
str
],
seen_mm_hashes
:
set
[
str
],
):
for
req
in
scheduler_output
.
scheduled_new_reqs
:
assert
req
.
req_id
not
in
seen_request_ids
seen_request_ids
.
add
(
req
.
req_id
)
for
req_id
in
scheduler_output
.
scheduled_cached_reqs
.
req_ids
:
assert
req_id
in
seen_request_ids
req_ids
=
set
[
str
]()
req_ids
.
update
(
req
.
req_id
for
req
in
scheduler_output
.
scheduled_new_reqs
)
req_ids
.
update
(
scheduler_output
.
scheduled_cached_reqs
.
req_ids
)
assert
set
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
==
req_ids
assert
(
sum
(
scheduler_output
.
num_scheduled_tokens
.
values
())
==
scheduler_output
.
total_num_scheduled_tokens
)
assert
set
(
scheduler_output
.
scheduled_spec_decode_tokens
.
keys
())
<=
req_ids
assert
set
(
scheduler_output
.
scheduled_encoder_inputs
.
keys
())
<=
req_ids
for
req
in
scheduler_output
.
scheduled_new_reqs
:
for
mm_feature
in
req
.
mm_features
:
seen_mm_hashes
.
add
(
mm_feature
.
identifier
)
for
mm_hash
in
scheduler_output
.
free_encoder_mm_hashes
:
assert
mm_hash
in
seen_mm_hashes
assert
scheduler_output
.
finished_req_ids
<=
seen_request_ids
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
@pytest.mark.parametrize("num_speculative_tokens", [None, 1, 5])
@pytest.mark.parametrize(
    ("max_input_tokens", "max_output_tokens", "max_num_seqs", "num_blocks"),
    [
        # Standard profile
        (5000, 500, 256, 10000),
        # Generation heavy + high max_num_seqs + low num_blocks -> Many preemptions
        (500, 5000, 1024, 1000),
    ],
    ids=["standard", "preemption"],
)
def test_priority_scheduling_blast(
    enable_prefix_caching: bool,
    num_speculative_tokens: int | None,
    max_input_tokens: int,
    max_output_tokens: int,
    max_num_seqs: int,
    num_blocks: int,
):
    """Randomised stress test of the priority scheduler.

    Drives the scheduler for a fixed number of steps with randomly generated
    requests and mocked model / draft outputs, validating every
    ``SchedulerOutput`` along the way. The test passes when no assertion
    or exception is raised during the run.
    """
    random.seed(42)

    # State threaded through the per-step validators.
    seen_request_prompt_length = dict[str, int]()
    seen_request_ids = set[str]()
    seen_mm_hashes = set[str]()

    scheduler = create_scheduler_with_priority(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        max_num_seqs=max_num_seqs,
        enable_prefix_caching=enable_prefix_caching,
        num_blocks=num_blocks,
        num_speculative_tokens=num_speculative_tokens,
    )

    def _enqueue_random_request() -> None:
        # Every request in this test is drawn from the same distribution.
        scheduler.add_request(
            _create_random_request(
                max_tokens_range=(1, max_output_tokens),
                num_tokens_range=(1, max_input_tokens),
                arrival_time_range=(0, 1),
                priority_range=(-3, 3),
                num_mm_item_range=(0, 2),
                vllm_config=scheduler.vllm_config,
            )
        )

    num_initial_requests = 10
    for _ in range(num_initial_requests):
        _enqueue_random_request()

    # A step samples at least one token and at most one plus the number of
    # speculative tokens.
    sampled_tokens_range = (1, 1 + (num_speculative_tokens or 0))

    num_steps = 20000
    for _step in range(num_steps):
        # Only top up the queue once the waiting queue has drained.
        if len(scheduler.waiting) == 0:
            for _new in range(random.randint(0, 2)):
                _enqueue_random_request()

        scheduler_output = scheduler.schedule()
        _chech_valid_scheduler_output(
            scheduler_output, seen_request_ids, seen_mm_hashes
        )

        model_output = _mock_execute_model(
            scheduler_output,
            num_output_tokens_range=sampled_tokens_range,
        )
        scheduler.update_from_output(scheduler_output, model_output)

        if num_speculative_tokens is None:
            continue
        draft_token_ids = _mock_draft_token_ids(
            scheduler_output,
            (0, num_speculative_tokens),
            seen_request_prompt_length,
        )
        scheduler.update_draft_token_ids(draft_token_ids)
vllm/v1/core/sched/scheduler.py
View file @
58ce8d12
...
...
@@ -300,6 +300,20 @@ class Scheduler(SchedulerInterface):
]
req_to_new_blocks
.
pop
(
preempted_req
.
request_id
)
num_scheduled_tokens
.
pop
(
preempted_req
.
request_id
)
scheduled_spec_decode_tokens
.
pop
(
preempted_req
.
request_id
,
None
)
preempted_encoder_inputs
=
scheduled_encoder_inputs
.
pop
(
preempted_req
.
request_id
,
None
)
if
preempted_encoder_inputs
:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_tokens_to_restore
=
sum
(
preempted_req
.
get_num_encoder_tokens
(
i
)
for
i
in
preempted_encoder_inputs
)
encoder_compute_budget
+=
num_tokens_to_restore
req_index
-=
1
else
:
preempted_req
=
self
.
running
.
pop
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment