Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7df331c6
Unverified
Commit
7df331c6
authored
Nov 22, 2025
by
Nick Hill
Committed by
GitHub
Nov 22, 2025
Browse files
[BugFix] Fix chunked prompt logprobs + preemption (#29071)
parent
eb5352a7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
127 additions
and
31 deletions
+127
-31
tests/conftest.py
tests/conftest.py
+23
-4
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+76
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+0
-14
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-3
vllm/v1/worker/tpu_input_batch.py
vllm/v1/worker/tpu_input_batch.py
+0
-10
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+11
-0
No files found.
tests/conftest.py
View file @
7df331c6
...
...
@@ -853,6 +853,7 @@ class VllmRunner:
@
staticmethod
def
_final_steps_generate_w_logprobs
(
req_outputs
:
list
[
RequestOutput
],
include_prompt_token_ids
:
bool
=
False
,
)
->
list
[
TokensTextLogprobsPromptLogprobs
]:
outputs
:
list
[
TokensTextLogprobsPromptLogprobs
]
=
[]
for
req_output
in
req_outputs
:
...
...
@@ -861,9 +862,26 @@ class VllmRunner:
output_str
=
sample
.
text
output_ids
=
list
(
sample
.
token_ids
)
output_logprobs
=
sample
.
logprobs
outputs
.
append
(
(
output_ids
,
output_str
,
output_logprobs
,
req_output
.
prompt_logprobs
)
)
if
include_prompt_token_ids
:
outputs
.
append
(
(
# type: ignore[arg-type]
output_ids
,
output_str
,
output_logprobs
,
req_output
.
prompt_token_ids
,
req_output
.
prompt_logprobs
,
)
)
else
:
outputs
.
append
(
(
output_ids
,
output_str
,
output_logprobs
,
req_output
.
prompt_logprobs
,
)
)
return
outputs
def
generate_w_logprobs
(
...
...
@@ -873,6 +891,7 @@ class VllmRunner:
images
:
PromptImageInput
|
None
=
None
,
audios
:
PromptAudioInput
|
None
=
None
,
videos
:
PromptVideoInput
|
None
=
None
,
include_prompt_token_ids
:
bool
=
False
,
**
kwargs
:
Any
,
)
->
list
[
TokensTextLogprobs
]
|
list
[
TokensTextLogprobsPromptLogprobs
]:
inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
...
...
@@ -882,7 +901,7 @@ class VllmRunner:
)
toks_str_logsprobs_prompt_logprobs
=
self
.
_final_steps_generate_w_logprobs
(
req_outputs
req_outputs
,
include_prompt_token_ids
)
# Omit prompt logprobs if not required by sampling params
return
(
...
...
tests/v1/sample/test_logprobs.py
View file @
7df331c6
...
...
@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
)
assert
ref_logprob
.
rank
==
spec_logprob
.
rank
assert
ref_logprob
.
decoded_token
==
spec_logprob
.
decoded_token
def
test_prompt_logprobs_with_chunking_and_preemption
():
"""Test that prompt logprobs are correctly returned when using
both chunked prefill and preemption.
This test ensures that the num_prompt_logprobs tracking persists
across preemptions and prefill chunks.
"""
# Create prompts that will trigger chunking and preemption
prompts
=
[
"The following numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
10
))
+
" are:"
,
"In one word, the capital of France is "
,
]
+
[
f
"Tell me about the number
{
i
}
: "
for
i
in
range
(
32
)]
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
40
,
min_tokens
=
20
,
prompt_logprobs
=
2
,
# Request prompt logprobs
)
with
VllmRunner
(
"Qwen/Qwen3-0.6B"
,
max_model_len
=
512
,
enable_chunked_prefill
=
True
,
max_num_batched_tokens
=
48
,
# Force prefill chunking
num_gpu_blocks_override
=
32
,
# Force preemptions
disable_log_stats
=
False
,
gpu_memory_utilization
=
0.25
,
)
as
vllm_model
:
metrics_before
=
vllm_model
.
llm
.
get_metrics
()
# Generate with prompt logprobs using generate_w_logprobs which
# returns (output_ids, output_str, output_logprobs, prompt_logprobs)
outputs
=
vllm_model
.
generate_w_logprobs
(
prompts
,
sampling_params
=
sampling_params
,
include_prompt_token_ids
=
True
)
# Verify that all outputs have prompt logprobs
for
i
,
output
in
enumerate
(
outputs
):
_
,
_
,
_
,
prompt_token_ids
,
prompt_logprobs
=
output
assert
prompt_logprobs
is
not
None
and
len
(
prompt_logprobs
)
>
0
,
(
f
"Output
{
i
}
missing prompt logprobs"
)
assert
len
(
prompt_logprobs
)
==
len
(
prompt_token_ids
),
(
"Unexpected number of prompt logprob positions"
)
# Each position should have the requested number of logprobs
for
pos
,
logprobs_dict
in
enumerate
(
prompt_logprobs
):
if
logprobs_dict
is
not
None
:
# First token may be None
assert
(
sampling_params
.
prompt_logprobs
<=
len
(
logprobs_dict
)
<=
sampling_params
.
prompt_logprobs
+
1
),
(
f
"Output
{
i
}
position
{
pos
}
has
{
len
(
logprobs_dict
)
}
"
f
"logprobs, expected
{
sampling_params
.
prompt_logprobs
}
"
)
# Check that we actually had preemptions
metrics_after
=
vllm_model
.
llm
.
get_metrics
()
preemptions_before
=
next
(
(
m
.
value
for
m
in
metrics_before
if
m
.
name
==
"vllm:num_preemptions"
),
0
)
preemptions_after
=
next
(
(
m
.
value
for
m
in
metrics_after
if
m
.
name
==
"vllm:num_preemptions"
),
0
)
preemptions
=
preemptions_after
-
preemptions_before
assert
preemptions
>
0
,
"Test did not trigger any preemptions"
print
(
f
"Test passed with
{
preemptions
}
preemptions"
)
vllm/v1/worker/gpu_input_batch.py
View file @
7df331c6
...
...
@@ -219,9 +219,6 @@ class InputBatch:
self
.
generators
:
dict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
dict
[
str
,
int
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self
.
in_progress_prompt_logprobs_cpu
:
dict
[
str
,
LogprobsTensors
]
=
{}
...
...
@@ -385,12 +382,6 @@ class InputBatch:
if
sampling_params
.
logprobs
==
-
1
else
sampling_params
.
logprobs
)
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
(
self
.
vocab_size
if
sampling_params
.
prompt_logprobs
==
-
1
else
sampling_params
.
prompt_logprobs
)
if
sampling_params
.
allowed_token_ids
:
self
.
has_allowed_token_ids
.
add
(
req_id
)
...
...
@@ -488,7 +479,6 @@ class InputBatch:
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
self
.
has_allowed_token_ids
.
discard
(
req_id
)
...
...
@@ -972,10 +962,6 @@ class InputBatch:
def
max_num_logprobs
(
self
)
->
int
|
None
:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
None
@
property
def
no_prompt_logprob
(
self
)
->
bool
:
return
not
self
.
num_prompt_logprobs
@
property
def
no_allowed_token_ids
(
self
)
->
bool
:
return
len
(
self
.
has_allowed_token_ids
)
==
0
vllm/v1/worker/gpu_model_runner.py
View file @
7df331c6
...
...
@@ -393,6 +393,9 @@ class GPUModelRunner(
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
self
.
comm_stream
=
torch
.
cuda
.
Stream
()
# Input Batch
...
...
@@ -687,6 +690,7 @@ class GPUModelRunner(
# Remove finished requests from the cached states.
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
requests
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
...
...
@@ -755,6 +759,13 @@ class GPUModelRunner(
)
self
.
requests
[
req_id
]
=
req_state
if
sampling_params
and
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
(
self
.
input_batch
.
vocab_size
if
sampling_params
.
prompt_logprobs
==
-
1
else
sampling_params
.
prompt_logprobs
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
self
.
_init_mrope_positions
(
req_state
)
...
...
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
scheduler_output
,
self
.
vllm_config
)
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
assert
not
self
.
input_batch
.
num_prompt_logprobs
,
(
assert
not
self
.
num_prompt_logprobs
,
(
"--kv-sharing-fast-prefill produces incorrect "
"logprobs for prompt tokens, tokens, please disable "
"it when the requests need prompt logprobs"
...
...
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
hidden_states
:
torch
.
Tensor
,
num_scheduled_tokens
:
dict
[
str
,
int
],
)
->
dict
[
str
,
LogprobsTensors
|
None
]:
num_prompt_logprobs_dict
=
self
.
input_batch
.
num_prompt_logprobs
num_prompt_logprobs_dict
=
self
.
num_prompt_logprobs
if
not
num_prompt_logprobs_dict
:
return
{}
...
...
@@ -3447,7 +3458,10 @@ class GPUModelRunner(
# maintainable loop over optimal performance.
completed_prefill_reqs
=
[]
for
req_id
,
num_prompt_logprobs
in
num_prompt_logprobs_dict
.
items
():
num_tokens
=
num_scheduled_tokens
[
req_id
]
num_tokens
=
num_scheduled_tokens
.
get
(
req_id
)
if
num_tokens
is
None
:
# This can happen if the request was preempted in prefill stage.
continue
# Get metadata for this request.
request
=
self
.
requests
[
req_id
]
...
...
vllm/v1/worker/tpu_input_batch.py
View file @
7df331c6
...
...
@@ -149,9 +149,6 @@ class InputBatch:
self
.
generators
:
dict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
dict
[
str
,
int
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self
.
in_progress_prompt_logprobs_cpu
:
dict
[
str
,
LogprobsTensors
]
=
{}
...
...
@@ -256,8 +253,6 @@ class InputBatch:
if
sampling_params
.
logprobs
is
not
None
:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
...
...
@@ -317,7 +312,6 @@ class InputBatch:
self
.
repetition_penalties_reqs
.
discard
(
req_id
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
# LoRA
...
...
@@ -584,10 +578,6 @@ class InputBatch:
def
max_num_logprobs
(
self
)
->
int
|
None
:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
None
@
property
def
no_prompt_logprob
(
self
)
->
bool
:
return
not
self
.
num_prompt_logprobs
@
property
def
no_allowed_token_ids
(
self
)
->
bool
:
return
len
(
self
.
has_allowed_token_ids
)
==
0
vllm/v1/worker/tpu_model_runner.py
View file @
7df331c6
...
...
@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# Initialize input batch early to avoid AttributeError in _update_states
self
.
input_batch
=
InputBatch
(
...
...
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
requests
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
...
...
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request
=
new_req_data
.
lora_request
,
)
if
sampling_params
and
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
(
self
.
input_batch
.
vocab_size
if
sampling_params
.
prompt_logprobs
==
-
1
else
sampling_params
.
prompt_logprobs
)
req_ids_to_add
.
append
(
req_id
)
# Update the states of the running/resumed requests.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment