Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ac85a45
Unverified
Commit
2ac85a45
authored
Dec 18, 2025
by
Nick Hill
Committed by
GitHub
Dec 18, 2025
Browse files
[BugFix] Fix logprobs with spec decode and modified logits (#30846)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
7b43db21
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
11 deletions
+30
-11
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+23
-10
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+7
-1
No files found.
tests/v1/sample/test_logprobs.py
View file @
2ac85a45
...
@@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
...
@@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
logprobs
=
top_logprobs
,
max_tokens
=
10
,
ignore_eos
=
False
temperature
=
0
,
logprobs
=
top_logprobs
,
max_tokens
=
10
,
ignore_eos
=
False
)
)
penalty_sampling_params
=
SamplingParams
(
temperature
=
0
,
logprobs
=
top_logprobs
,
max_tokens
=
10
,
ignore_eos
=
False
,
presence_penalty
=-
1.0
,
)
method
,
model_name
,
spec_model_name
=
model_setup
method
,
model_name
,
spec_model_name
=
model_setup
max_model_len
=
256
max_model_len
=
256
...
@@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
...
@@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
seed
=
42
,
seed
=
42
,
logprobs_mode
=
logprobs_mode
,
logprobs_mode
=
logprobs_mode
,
gpu_memory_utilization
=
0.4
,
gpu_memory_utilization
=
0.4
,
enable_prefix_caching
=
False
,
)
ref_results
=
ref_llm
.
generate
(
[
prompt
,
prompt
],
[
sampling_params
,
penalty_sampling_params
]
)
)
ref_results
=
ref_llm
.
generate
([
prompt
],
sampling_params
)
# Collect logprobs outputs from reference LLM.
# Collect logprobs outputs from reference LLM.
ref_logprobs
=
[]
ref_logprobs
=
[]
for
output
in
ref_results
[
0
].
outputs
:
for
results
in
ref_results
:
for
logprobs
in
output
.
logprob
s
:
for
output
in
results
.
output
s
:
for
token_id
in
logprobs
:
for
logprobs
in
output
.
logprobs
:
ref_logprobs
.
app
end
(
logprobs
[
token_id
]
)
ref_logprobs
.
ext
end
(
logprobs
.
values
()
)
del
ref_llm
del
ref_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
...
@@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
...
@@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
# Force prefill chunking
# Force prefill chunking
enable_chunked_prefill
=
True
,
enable_chunked_prefill
=
True
,
max_num_batched_tokens
=
32
,
max_num_batched_tokens
=
32
,
enable_prefix_caching
=
False
,
)
spec_results
=
spec_llm
.
generate
(
[
prompt
,
prompt
],
[
sampling_params
,
penalty_sampling_params
]
)
)
spec_results
=
spec_llm
.
generate
([
prompt
],
sampling_params
)
# Collect logprobs outputs from spec decode LLM.
# Collect logprobs outputs from spec decode LLM.
spec_logprobs
=
[]
spec_logprobs
=
[]
for
output
in
spec_results
[
0
].
outputs
:
for
results
in
spec_results
:
for
logprobs
in
output
.
logprob
s
:
for
output
in
results
.
output
s
:
for
token_id
in
logprobs
:
for
logprobs
in
output
.
logprobs
:
spec_logprobs
.
app
end
(
logprobs
[
token_id
]
)
spec_logprobs
.
ext
end
(
logprobs
.
values
()
)
del
spec_llm
del
spec_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
...
...
vllm/v1/sample/rejection_sampler.py
View file @
2ac85a45
...
@@ -119,8 +119,14 @@ class RejectionSampler(nn.Module):
...
@@ -119,8 +119,14 @@ class RejectionSampler(nn.Module):
raw_target_logits
=
logits
[
target_logits_indices
]
raw_target_logits
=
logits
[
target_logits_indices
]
# Use float32 for the target_logits.
# Use float32 for the target_logits.
raw_target_logits
=
raw_target_logits
.
to
(
torch
.
float32
)
raw_target_logits
=
raw_target_logits
.
to
(
torch
.
float32
)
target_logits
=
raw_target_logits
if
not
self
.
is_processed_logprobs_mode
:
# Clone raw_target_logits before applying processors to preserve
# the original raw logits for logprobs computation, since
# apply_logits_processors modifies the tensor in-place.
target_logits
=
target_logits
.
clone
()
target_logits
=
self
.
apply_logits_processors
(
target_logits
=
self
.
apply_logits_processors
(
raw_
target_logits
,
sampling_metadata
,
metadata
target_logits
,
sampling_metadata
,
metadata
)
)
# [num_tokens, vocab_size]
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# NOTE(woosuk): `target_logits` can be updated in place inside the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment