Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44822d7f
Unverified
Commit
44822d7f
authored
Dec 01, 2025
by
Nick Hill
Committed by
GitHub
Dec 01, 2025
Browse files
[BugFix] Preserve spec decoding uniform decode when scheduling (#29759)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
342c4f14
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
17 deletions
+25
-17
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+2
-2
vllm/v1/core/sched/async_scheduler.py
vllm/v1/core/sched/async_scheduler.py
+1
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+22
-14
No files found.
tests/v1/e2e/test_spec_decode.py
View file @
44822d7f
...
@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
...
@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve.
# Expect the acceptance rate to improve.
assert
first_accept_rate
<
last_accept_rate
assert
first_accept_rate
<
last_accept_rate
# Heuristic: expect at least 85% acceptance rate at the end.
# Heuristic: expect at least 8
2.
5% acceptance rate at the end.
assert
last_accept_rate
>
0.85
assert
last_accept_rate
>
0.8
2
5
del
spec_llm
del
spec_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
vllm/v1/core/sched/async_scheduler.py
View file @
44822d7f
...
@@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler):
...
@@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler):
# in this scheduling step.
# in this scheduling step.
request
.
num_output_placeholders
+=
1
+
cur_num_spec_tokens
request
.
num_output_placeholders
+=
1
+
cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# Add placeholders for the new tokens in spec_token_ids.
# W
w
e will update the actual spec token ids in the worker process.
# We will update the actual spec token ids in the worker process.
request
.
spec_token_ids
=
[
-
1
]
*
self
.
num_spec_tokens
request
.
spec_token_ids
=
[
-
1
]
*
self
.
num_spec_tokens
scheduler_output
.
pending_structured_output_tokens
=
(
scheduler_output
.
pending_structured_output_tokens
=
(
...
...
vllm/v1/core/sched/scheduler.py
View file @
44822d7f
...
@@ -236,6 +236,22 @@ class Scheduler(SchedulerInterface):
...
@@ -236,6 +236,22 @@ class Scheduler(SchedulerInterface):
while
req_index
<
len
(
self
.
running
)
and
token_budget
>
0
:
while
req_index
<
len
(
self
.
running
)
and
token_budget
>
0
:
request
=
self
.
running
[
req_index
]
request
=
self
.
running
[
req_index
]
if
(
request
.
num_output_placeholders
>
0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
# Since output placeholders are also included in the computed tokens
# count, we subtract (num_output_placeholders - 1) to remove any draft
# tokens, so that we can be sure no further steps are needed even if
# they are all rejected.
and
request
.
num_computed_tokens
+
2
-
request
.
num_output_placeholders
>=
request
.
num_prompt_tokens
+
request
.
max_tokens
):
# Async scheduling: Avoid scheduling an extra step when we are sure that
# the previous step has reached request.max_tokens. We don't schedule
# partial draft tokens since this prevents uniform decode optimizations.
req_index
+=
1
continue
num_new_tokens
=
(
num_new_tokens
=
(
request
.
num_tokens_with_spec
request
.
num_tokens_with_spec
+
request
.
num_output_placeholders
+
request
.
num_output_placeholders
...
@@ -245,18 +261,10 @@ class Scheduler(SchedulerInterface):
...
@@ -245,18 +261,10 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_spec_placeholders
=
max
(
0
,
request
.
num_output_placeholders
-
1
)
# Make sure the input position does not exceed the max model len.
max_total_tokens
=
min
(
# This is necessary when using spec decoding.
# Avoid scheduling tokens that we're sure won't will be needed based on
# request.max_tokens. For this calculation we assume placeholder
# speculated output tokens are rejected.
request
.
num_prompt_tokens
+
request
.
max_tokens
+
num_spec_placeholders
,
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
self
.
max_model_len
,
)
num_new_tokens
=
min
(
num_new_tokens
=
min
(
num_new_tokens
,
max_total_tok
en
s
-
1
-
request
.
num_computed_tokens
num_new_tokens
,
self
.
max_model_l
en
-
1
-
request
.
num_computed_tokens
)
)
# Schedule encoder inputs.
# Schedule encoder inputs.
...
@@ -799,15 +807,15 @@ class Scheduler(SchedulerInterface):
...
@@ -799,15 +807,15 @@ class Scheduler(SchedulerInterface):
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
req_id
=
req
.
request_id
req_id
=
req
.
request_id
req_ids
.
append
(
req_id
)
req_ids
.
append
(
req_id
)
num_tokens
=
num_scheduled_tokens
[
req_id
]
-
len
(
spec_decode_tokens
.
get
(
req_id
,
())
)
if
self
.
use_pp
:
if
self
.
use_pp
:
# When using PP, the scheduler sends the sampled tokens back,
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# need to send the sampled tokens back because the model runner
# will cache them.
# will cache them.
num_tokens
=
num_scheduled_tokens
[
req_id
]
-
len
(
spec_decode_tokens
.
get
(
req_id
,
())
)
token_ids
=
req
.
all_token_ids
[
token_ids
=
req
.
all_token_ids
[
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
req
.
num_computed_tokens
:
req
.
num_computed_tokens
+
num_tokens
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment