Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4e57c658
Unverified
Commit
4e57c658
authored
Nov 25, 2025
by
Nick Hill
Committed by
GitHub
Nov 25, 2025
Browse files
[Core] Support logprobs with spec decode + async scheduling (#29223)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
e7d77627
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
25 deletions
+35
-25
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+6
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+0
-2
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+12
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-20
No files found.
tests/v1/e2e/test_async_scheduling.py
View file @
4e57c658
...
@@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
...
@@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short
=
spec_config
|
{
"max_model_len"
:
50
}
spec_config_short
=
spec_config
|
{
"max_model_len"
:
50
}
test_sampling_params
=
[
dict
(),
dict
(
logprobs
=
2
),
]
# test_preemption, executor, async_scheduling,
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
# spec_config, test_prefill_chunking
test_configs
=
[
test_configs
=
[
...
@@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
...
@@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(
True
,
"uni"
,
True
,
spec_config_short
,
True
),
(
True
,
"uni"
,
True
,
spec_config_short
,
True
),
]
]
run_tests
(
monkeypatch
,
MTP_MODEL
,
test_configs
,
[{}]
)
run_tests
(
monkeypatch
,
MTP_MODEL
,
test_configs
,
test_sampling_params
)
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
@
dynamo_config
.
patch
(
cache_size_limit
=
16
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
4e57c658
...
@@ -1089,8 +1089,6 @@ class Scheduler(SchedulerInterface):
...
@@ -1089,8 +1089,6 @@ class Scheduler(SchedulerInterface):
and
request
.
sampling_params
.
logprobs
is
not
None
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
and
logprobs
):
):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
self
.
structured_output_manager
.
should_advance
(
request
):
if
new_token_ids
and
self
.
structured_output_manager
.
should_advance
(
request
):
...
...
vllm/v1/sample/rejection_sampler.py
View file @
4e57c658
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
dataclasses
import
replace
from
dataclasses
import
replace
import
torch
import
torch
...
@@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
...
@@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
def
parse_output
(
def
parse_output
(
output_token_ids
:
torch
.
Tensor
,
output_token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
vocab_size
:
int
,
)
->
list
[
list
[
int
]]:
discard_req_indices
:
Sequence
[
int
]
=
(),
return_cu_num_tokens
:
bool
=
False
,
)
->
tuple
[
list
[
list
[
int
]],
list
[
int
]
|
None
]:
"""Parse the output of the rejection sampler.
"""Parse the output of the rejection sampler.
Args:
Args:
output_token_ids: The sampled token IDs in shape
output_token_ids: The sampled token IDs in shape
...
@@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
...
@@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts.
Returns:
Returns:
A list of lists of token IDs.
A list of lists of token IDs.
"""
"""
...
@@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
...
@@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
valid_mask
=
(
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
(
valid_mask
=
(
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
(
output_token_ids_np
<
vocab_size
output_token_ids_np
<
vocab_size
)
)
cu_num_tokens
=
None
if
return_cu_num_tokens
:
cu_num_tokens
=
[
0
]
+
valid_mask
.
sum
(
axis
=
1
).
cumsum
().
tolist
()
if
len
(
discard_req_indices
)
>
0
:
valid_mask
[
discard_req_indices
]
=
False
outputs
=
[
outputs
=
[
row
[
valid_mask
[
i
]].
tolist
()
for
i
,
row
in
enumerate
(
output_token_ids_np
)
row
[
valid_mask
[
i
]].
tolist
()
for
i
,
row
in
enumerate
(
output_token_ids_np
)
]
]
return
outputs
return
outputs
,
cu_num_tokens
def
apply_logits_processors
(
def
apply_logits_processors
(
self
,
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4e57c658
...
@@ -183,7 +183,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -183,7 +183,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self
,
self
,
model_runner_output
:
ModelRunnerOutput
,
model_runner_output
:
ModelRunnerOutput
,
sampled_token_ids
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensors
:
torch
.
Tensor
|
None
,
logprobs_tensors
:
Logprobs
Tensor
s
|
None
,
invalid_req_indices
:
list
[
int
],
invalid_req_indices
:
list
[
int
],
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
vocab_size
:
int
,
vocab_size
:
int
,
...
@@ -219,28 +219,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -219,28 +219,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
This function blocks until the copy is finished.
This function blocks until the copy is finished.
"""
"""
max_gen_len
=
self
.
sampled_token_ids_cpu
.
shape
[
-
1
]
self
.
async_copy_ready_event
.
synchronize
()
self
.
async_copy_ready_event
.
synchronize
()
# Release the device tensors once the copy has completed.
# Release the device tensors once the copy has completed.
del
self
.
_logprobs_tensors
del
self
.
_logprobs_tensors
del
self
.
_sampled_token_ids
del
self
.
_sampled_token_ids
max_gen_len
=
self
.
sampled_token_ids_cpu
.
shape
[
-
1
]
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
self
.
sampled_token_ids_cpu
.
tolist
()
valid_sampled_token_ids
=
self
.
sampled_token_ids_cpu
.
tolist
()
for
i
in
self
.
_invalid_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
cu_num_tokens
=
None
else
:
else
:
valid_sampled_token_ids
=
RejectionSampler
.
parse_output
(
valid_sampled_token_ids
,
cu_num_tokens
=
RejectionSampler
.
parse_output
(
self
.
sampled_token_ids_cpu
,
self
.
sampled_token_ids_cpu
,
self
.
vocab_size
,
self
.
vocab_size
,
self
.
_invalid_req_indices
,
return_cu_num_tokens
=
self
.
_logprobs_tensors_cpu
is
not
None
,
)
)
for
i
in
self
.
_invalid_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
output
=
self
.
_model_runner_output
output
=
self
.
_model_runner_output
output
.
sampled_token_ids
=
valid_sampled_token_ids
output
.
sampled_token_ids
=
valid_sampled_token_ids
if
self
.
_logprobs_tensors_cpu
:
if
self
.
_logprobs_tensors_cpu
:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
output
.
logprobs
=
self
.
_logprobs_tensors_cpu
.
tolists
(
cu_num_tokens
)
# for async sched + spec decode + logprobs compatibility.
output
.
logprobs
=
self
.
_logprobs_tensors_cpu
.
tolists
()
return
output
return
output
...
@@ -2597,28 +2598,24 @@ class GPUModelRunner(
...
@@ -2597,28 +2598,24 @@ class GPUModelRunner(
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
logprobs_tensors
=
sampler_output
.
logprobs_tensors
logprobs_tensors
=
sampler_output
.
logprobs_tensors
invalid_req_indices
=
[]
invalid_req_indices
=
[]
cu_num_
new_
tokens
:
list
[
int
]
|
None
=
None
cu_num_tokens
:
list
[
int
]
|
None
=
None
if
not
self
.
use_async_scheduling
:
if
not
self
.
use_async_scheduling
:
# Get the valid generated tokens.
# Get the valid generated tokens.
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
# No spec decode tokens.
# No spec decode tokens.
valid_sampled_token_ids
=
self
.
_to_list
(
sampled_token_ids
)
valid_sampled_token_ids
=
self
.
_to_list
(
sampled_token_ids
)
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
int
(
i
)].
clear
()
else
:
else
:
# Includes spec decode tokens.
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
r
ejection
_s
ampler
.
parse_output
(
valid_sampled_token_ids
,
cu_num_tokens
=
R
ejection
S
ampler
.
parse_output
(
sampled_token_ids
,
sampled_token_ids
,
self
.
input_batch
.
vocab_size
,
self
.
input_batch
.
vocab_size
,
discard_sampled_tokens_req_indices
,
return_cu_num_tokens
=
logprobs_tensors
is
not
None
,
)
)
if
logprobs_tensors
:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens
=
[
0
]
for
toks
in
valid_sampled_token_ids
:
cu_num_new_tokens
.
append
(
cu_num_new_tokens
[
-
1
]
+
len
(
toks
))
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
int
(
i
)].
clear
()
else
:
else
:
valid_sampled_token_ids
=
[]
valid_sampled_token_ids
=
[]
invalid_req_indices
=
discard_sampled_tokens_req_indices
.
tolist
()
invalid_req_indices
=
discard_sampled_tokens_req_indices
.
tolist
()
...
@@ -2672,7 +2669,7 @@ class GPUModelRunner(
...
@@ -2672,7 +2669,7 @@ class GPUModelRunner(
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
logprobs_lists
=
(
logprobs_lists
=
(
logprobs_tensors
.
tolists
(
cu_num_
new_
tokens
)
logprobs_tensors
.
tolists
(
cu_num_tokens
)
if
not
self
.
use_async_scheduling
and
logprobs_tensors
is
not
None
if
not
self
.
use_async_scheduling
and
logprobs_tensors
is
not
None
else
None
else
None
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment