Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5bc26c43
Unverified
Commit
5bc26c43
authored
Oct 10, 2025
by
Nick Hill
Committed by
GitHub
Oct 10, 2025
Browse files
[BugFix] Make penalties and bad_words work with async scheduling (#26467)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
eef921f4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
14 deletions
+113
-14
tests/v1/e2e/test_async_sched_and_preempt.py
tests/v1/e2e/test_async_sched_and_preempt.py
+19
-5
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+3
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+67
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+24
-7
No files found.
tests/v1/e2e/test_async_sched_and_preempt.py
View file @
5bc26c43
...
...
@@ -28,9 +28,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
sampling_param_tests
:
list
[
dict
[
str
,
Any
]]
=
[
dict
(),
# dict(min_tokens=20),
# TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
# dict(repetition_penalty=0.1),
# dict(bad_words=[]),
dict
(
presence_penalty
=-
1.0
),
dict
(
bad_words
=
[
"the"
,
" the"
]),
]
default_params
=
dict
(
...
...
@@ -42,9 +41,9 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
# m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
outputs
=
[]
outputs
:
list
[
tuple
[
str
,
list
]]
=
[]
for
test_preemption
in
[
False
,
True
]:
for
executor
in
[
"
uni
"
,
"
mp
"
]:
for
executor
in
[
"
mp
"
,
"
uni
"
]:
for
async_scheduling
in
[
False
,
True
]:
cache_arg
:
dict
[
str
,
Any
]
=
(
dict
(
num_gpu_blocks_override
=
32
)
...
...
@@ -78,6 +77,21 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
),
)
)
if
not
outputs
:
# First check that the different parameter configs
# actually result in different output.
for
other_test
,
params
in
zip
(
results
[
1
:],
sampling_param_tests
[
1
:]
):
with
pytest
.
raises
(
AssertionError
):
check_outputs_equal
(
outputs_0_lst
=
results
[
0
],
outputs_1_lst
=
other_test
,
name_0
=
f
"baseline params=
{
params
}
"
,
name_1
=
f
"other params=
{
params
}
"
,
)
outputs
.
append
((
test_config
,
results
))
baseline_config
,
baseline_tests
=
outputs
[
0
]
...
...
vllm/v1/core/sched/scheduler.py
View file @
5bc26c43
...
...
@@ -737,7 +737,9 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
)
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
num_output_tokens
.
append
(
req
.
num_output_tokens
)
num_output_tokens
.
append
(
req
.
num_output_tokens
+
req
.
num_output_placeholders
)
return
CachedRequestData
(
req_ids
=
req_ids
,
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
5bc26c43
...
...
@@ -79,6 +79,7 @@ class InputBatch:
block_sizes
:
list
[
int
],
# The block_size of each kv cache group
kernel_block_sizes
:
list
[
int
],
logitsprocs
:
Optional
[
LogitsProcessors
]
=
None
,
logitsprocs_need_output_token_ids
:
bool
=
False
,
is_spec_decode
:
bool
=
False
,
is_pooling_model
:
bool
=
False
,
num_speculative_tokens
:
int
=
0
,
...
...
@@ -240,6 +241,7 @@ class InputBatch:
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self
.
logitsprocs
=
logitsprocs
or
LogitsProcessors
()
self
.
logitsprocs_need_output_token_ids
=
logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler.
self
.
spec_token_ids
:
list
[
Optional
[
list
[
int
]]]
=
[]
...
...
@@ -252,6 +254,11 @@ class InputBatch:
# Cached reference to the GPU tensor of previously sampled tokens
self
.
prev_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
prev_req_id_to_index
:
Optional
[
dict
[
str
,
int
]]
=
None
# These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self
.
sampled_token_ids_cpu
:
Optional
[
torch
.
Tensor
]
=
None
self
.
async_copy_ready_event
:
Optional
[
torch
.
cuda
.
Event
]
=
None
@
property
def
req_ids
(
self
)
->
list
[
str
]:
...
...
@@ -776,6 +783,19 @@ class InputBatch:
self
.
_make_prompt_token_ids_tensor
()
if
needs_prompt_token_ids
else
None
)
# Only set output_token_ids if required by the current requests'
# sampling parameters.
needs_output_token_ids
=
(
not
self
.
no_penalties
or
bool
(
self
.
bad_words_token_ids
)
or
self
.
logitsprocs_need_output_token_ids
)
output_token_ids
=
(
cast
(
list
[
list
[
int
]],
self
.
req_output_token_ids
)
if
needs_output_token_ids
else
[]
)
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
no_allowed_token_ids
:
assert
self
.
allowed_token_ids_mask
is
not
None
...
...
@@ -798,7 +818,7 @@ class InputBatch:
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
output_token_ids
=
cast
(
list
[
list
[
int
]],
self
.
req_
output_token_ids
)
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
cast
(
list
[
list
[
int
]],
self
.
spec_token_ids
),
no_penalties
=
self
.
no_penalties
,
allowed_token_ids_mask
=
allowed_token_ids_mask
,
...
...
@@ -859,6 +879,52 @@ class InputBatch:
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
def set_async_sampled_token_ids(
    self,
    sampled_token_ids_cpu: torch.Tensor,
    async_copy_ready_event: torch.cuda.Event,
) -> None:
    """Cache the sampled-token CPU tensor and its copy-ready event.

    Intended for the async scheduling path. Both references are kept only
    when the current sampling metadata holds a non-empty
    ``output_token_ids``; otherwise both cached attributes are reset to
    ``None``.

    Args:
        sampled_token_ids_cpu: CPU tensor of the sampled token ids.
        async_copy_ready_event: Event associated with the copy of those
            ids to the CPU tensor.
    """
    # A non-empty output_token_ids list is the signal that the current
    # sampling params actually consume output ids.
    keep_refs = bool(self.sampling_metadata.output_token_ids)
    self.sampled_token_ids_cpu = sampled_token_ids_cpu if keep_refs else None
    self.async_copy_ready_event = (
        async_copy_ready_event if keep_refs else None
    )
def update_async_output_token_ids(self) -> None:
    """Fill in placeholder output token ids with the real sampled ids.

    In the async scheduling case, the sampling metadata's
    ``output_token_ids`` may end with a ``-1`` placeholder for the token
    sampled in the prior step. Once that step's ids have finished copying
    to the CPU, each placeholder is overwritten in place with the real id.
    This is called right before the ids are needed by the logits
    processors.

    Does nothing when no sampled-id tensor is cached or when the sampling
    metadata has no output token ids.
    """
    pending = self.sampling_metadata.output_token_ids
    if not pending or self.sampled_token_ids_cpu is None:
        # Output token ids not needed or not async scheduling.
        return

    prev_positions = self.prev_req_id_to_index
    assert prev_positions is not None

    # Materialized lazily so we only wait on the copy event if at least
    # one request actually has a placeholder to replace.
    resolved_ids = None
    for row, request_id in enumerate(self.req_ids):
        source_row = prev_positions.get(request_id)
        if source_row is None:
            # Request was not part of the previous step.
            continue

        tokens = pending[row]
        has_placeholder = bool(tokens) and tokens[-1] == -1
        if not has_placeholder:
            # Final output id is not a placeholder, some tokens must have
            # been discarded after a kv-load failure.
            continue

        if resolved_ids is None:
            assert self.async_copy_ready_event is not None
            self.async_copy_ready_event.synchronize()
            resolved_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()

        # Replace placeholder token id with actual sampled id.
        tokens[-1] = resolved_ids[source_row]
@property
def num_reqs(self) -> int:
    """Number of entries currently held in ``req_id_to_index``."""
    index_by_req_id = self.req_id_to_index
    return len(index_by_req_id)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
5bc26c43
...
...
@@ -178,7 +178,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self
.
_invalid_req_indices
=
invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
self
.
_
async_copy_ready_event
=
torch
.
cuda
.
Event
()
self
.
async_copy_ready_event
=
torch
.
cuda
.
Event
()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
...
...
@@ -188,22 +188,22 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
async_output_copy_stream
):
async_output_copy_stream
.
wait_stream
(
default_stream
)
self
.
_
sampled_token_ids_cpu
=
self
.
_sampled_token_ids
.
to
(
self
.
sampled_token_ids_cpu
=
self
.
_sampled_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
)
self
.
_
async_copy_ready_event
.
record
()
self
.
async_copy_ready_event
.
record
()
def
get_output
(
self
)
->
ModelRunnerOutput
:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
This function blocks until the copy is finished.
"""
self
.
_
async_copy_ready_event
.
synchronize
()
self
.
async_copy_ready_event
.
synchronize
()
# Release the device tensor once the copy has completed
del
self
.
_sampled_token_ids
valid_sampled_token_ids
=
self
.
_
sampled_token_ids_cpu
.
tolist
()
valid_sampled_token_ids
=
self
.
sampled_token_ids_cpu
.
tolist
()
for
i
in
self
.
_invalid_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
...
...
@@ -349,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
custom_logitsprocs
=
model_config
.
logits_processors
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
# We need to use the encoder length for encoder-decoder
...
...
@@ -366,8 +367,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
device
,
self
.
pin_memory
,
self
.
is_pooling_model
,
self
.
vllm_config
.
model_config
.
logits
_
proc
essor
s
,
custom_
logitsprocs
,
),
# We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids
=
bool
(
custom_logitsprocs
),
is_pooling_model
=
self
.
is_pooling_model
,
)
...
...
@@ -2210,6 +2214,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
spec_decode_metadata
is
None
:
# Update output token ids with tokens sampled in last step
# if async scheduling and required by current sampling params.
self
.
input_batch
.
update_async_output_token_ids
()
return
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
@@ -2666,13 +2673,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
self
.
use_async_scheduling
:
return
output
return
AsyncGPUModelRunnerOutput
(
async_output
=
AsyncGPUModelRunnerOutput
(
model_runner_output
=
output
,
sampled_token_ids
=
sampler_output
.
sampled_token_ids
,
invalid_req_indices
=
invalid_req_indices
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
)
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that require output ids.
self
.
input_batch
.
set_async_sampled_token_ids
(
async_output
.
sampled_token_ids_cpu
,
async_output
.
async_copy_ready_event
,
)
return
async_output
def
take_draft_token_ids
(
self
)
->
Optional
[
DraftTokenIds
]:
if
self
.
_draft_token_ids
is
None
:
return
None
...
...
@@ -4198,6 +4214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kernel_block_sizes
=
kernel_block_sizes
,
is_spec_decode
=
bool
(
self
.
vllm_config
.
speculative_config
),
logitsprocs
=
self
.
input_batch
.
logitsprocs
,
logitsprocs_need_output_token_ids
=
self
.
input_batch
.
logitsprocs_need_output_token_ids
,
is_pooling_model
=
self
.
is_pooling_model
,
num_speculative_tokens
=
(
self
.
vllm_config
.
speculative_config
.
num_speculative_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment