Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d44a63c6
Unverified
Commit
d44a63c6
authored
Nov 22, 2025
by
Nick Hill
Committed by
GitHub
Nov 22, 2025
Browse files
[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
066209a0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
15 deletions
+22
-15
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+9
-4
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+4
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+9
-10
No files found.
tests/v1/sample/test_logprobs.py
View file @
d44a63c6
...
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
...
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
pytest
.
param
(
pytest
.
param
(
(
(
"eagle"
,
"eagle"
,
"meta-llama/Llama-3.
1-8
B-Instruct"
,
"meta-llama/Llama-3.
2-1
B-Instruct"
,
"
yuhuili/EAGLE-LLaMA3.1-Instruct-8B
"
,
"
nm-testing/Llama3_2_1B_speculator.eagle3
"
,
),
),
marks
=
large_gpu_mark
(
min_gb
=
32
),
marks
=
large_gpu_mark
(
min_gb
=
32
),
),
),
...
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
...
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
"""
"""
from
vllm
import
LLM
from
vllm
import
LLM
prompt
=
"Hello world
"
prompt
=
"Hello world
"
*
50
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
logprobs
=
3
,
max_tokens
=
10
,
ignore_eos
=
False
temperature
=
0
,
logprobs
=
3
,
max_tokens
=
10
,
ignore_eos
=
False
)
)
...
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
...
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
seed
=
42
,
seed
=
42
,
logprobs_mode
=
logprobs_mode
,
logprobs_mode
=
logprobs_mode
,
gpu_memory_utilization
=
0.4
,
gpu_memory_utilization
=
0.4
,
# Force prefill chunking
enable_chunked_prefill
=
True
,
max_num_batched_tokens
=
32
,
)
)
spec_results
=
spec_llm
.
generate
([
prompt
],
sampling_params
)
spec_results
=
spec_llm
.
generate
([
prompt
],
sampling_params
)
# Collect logprobs outputs from spec decode LLM.
# Collect logprobs outputs from spec decode LLM.
...
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
...
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
# Per-token logprobs are expected to be the same.
# Per-token logprobs are expected to be the same.
assert
len
(
ref_logprobs
)
==
len
(
spec_logprobs
)
assert
len
(
ref_logprobs
)
==
len
(
spec_logprobs
)
for
ref_logprob
,
spec_logprob
in
zip
(
ref_logprobs
,
spec_logprobs
):
for
ref_logprob
,
spec_logprob
in
zip
(
ref_logprobs
,
spec_logprobs
):
assert
math
.
isclose
(
ref_logprob
.
logprob
,
spec_logprob
.
logprob
,
abs_tol
=
1e-3
)
assert
math
.
isclose
(
ref_logprob
.
logprob
,
spec_logprob
.
logprob
,
rel_tol
=
5e-2
,
abs_tol
=
1e-1
)
assert
ref_logprob
.
rank
==
spec_logprob
.
rank
assert
ref_logprob
.
rank
==
spec_logprob
.
rank
assert
ref_logprob
.
decoded_token
==
spec_logprob
.
decoded_token
assert
ref_logprob
.
decoded_token
==
spec_logprob
.
decoded_token
vllm/v1/sample/sampler.py
View file @
d44a63c6
...
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
...
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
if
logprobs_mode
==
"raw_logprobs"
:
if
logprobs_mode
==
"raw_logprobs"
:
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
raw_logprobs
=
self
.
compute_logprobs
(
logits
)
elif
logprobs_mode
==
"raw_logits"
:
elif
logprobs_mode
==
"raw_logits"
:
raw_logprobs
=
logits
.
clone
()
if
logits
.
dtype
==
torch
.
float32
:
raw_logprobs
=
logits
.
clone
()
else
:
raw_logprobs
=
logits
.
to
(
torch
.
float32
)
# Use float32 for the logits.
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
logits
=
logits
.
to
(
torch
.
float32
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d44a63c6
...
@@ -2466,7 +2466,9 @@ class GPUModelRunner(
...
@@ -2466,7 +2466,9 @@ class GPUModelRunner(
num_sampled_tokens
=
sampler_output
.
sampled_token_ids
.
shape
[
0
]
num_sampled_tokens
=
sampler_output
.
sampled_token_ids
.
shape
[
0
]
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
logprobs_tensors
=
sampler_output
.
logprobs_tensors
invalid_req_indices
=
[]
invalid_req_indices
=
[]
cu_num_new_tokens
:
list
[
int
]
|
None
=
None
if
not
self
.
use_async_scheduling
:
if
not
self
.
use_async_scheduling
:
# Get the valid generated tokens.
# Get the valid generated tokens.
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
...
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
...
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
sampled_token_ids
,
sampled_token_ids
,
self
.
input_batch
.
vocab_size
,
self
.
input_batch
.
vocab_size
,
)
)
if
logprobs_tensors
:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens
=
[
0
]
for
toks
in
valid_sampled_token_ids
:
cu_num_new_tokens
.
append
(
cu_num_new_tokens
[
-
1
]
+
len
(
toks
))
# Mask out the sampled tokens that should not be sampled.
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
int
(
i
)].
clear
()
valid_sampled_token_ids
[
int
(
i
)].
clear
()
...
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
...
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
# the sampled tokens back, because there's no direct communication
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
# between the first-stage worker and the last-stage worker.
req_ids
=
self
.
input_batch
.
req_ids
req_ids
=
self
.
input_batch
.
req_ids
logprobs_tensors
=
sampler_output
.
logprobs_tensors
cu_num_accepted_tokens
=
(
[
0
]
if
spec_decode_metadata
and
logprobs_tensors
else
None
)
for
req_idx
in
range
(
num_sampled_tokens
):
for
req_idx
in
range
(
num_sampled_tokens
):
if
self
.
use_async_scheduling
:
if
self
.
use_async_scheduling
:
sampled_ids
=
[
-
1
]
if
req_idx
not
in
invalid_req_indices_set
else
None
sampled_ids
=
[
-
1
]
if
req_idx
not
in
invalid_req_indices_set
else
None
...
@@ -2518,11 +2522,6 @@ class GPUModelRunner(
...
@@ -2518,11 +2522,6 @@ class GPUModelRunner(
num_sampled_ids
:
int
=
len
(
sampled_ids
)
if
sampled_ids
else
0
num_sampled_ids
:
int
=
len
(
sampled_ids
)
if
sampled_ids
else
0
if
cu_num_accepted_tokens
is
not
None
:
cu_num_accepted_tokens
.
append
(
cu_num_accepted_tokens
[
-
1
]
+
num_sampled_ids
)
if
not
sampled_ids
:
if
not
sampled_ids
:
continue
continue
...
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
...
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
logprobs_lists
=
(
logprobs_lists
=
(
logprobs_tensors
.
tolists
(
cu_num_
accepted
_tokens
)
logprobs_tensors
.
tolists
(
cu_num_
new
_tokens
)
if
not
self
.
use_async_scheduling
and
logprobs_tensors
is
not
None
if
not
self
.
use_async_scheduling
and
logprobs_tensors
is
not
None
else
None
else
None
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment