Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e7a8916
Unverified
Commit
8e7a8916
authored
Nov 28, 2025
by
Nick Hill
Committed by
GitHub
Nov 28, 2025
Browse files
[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
953d9c82
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
38 deletions
+28
-38
tests/v1/test_outputs.py
tests/v1/test_outputs.py
+11
-13
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+9
-5
vllm/v1/outputs.py
vllm/v1/outputs.py
+8
-20
No files found.
tests/v1/test_outputs.py
View file @
8e7a8916
...
@@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
...
@@ -43,7 +43,7 @@ class TestLogprobsLists(TestCase):
cu_num_generated_tokens
=
None
,
cu_num_generated_tokens
=
None
,
)
)
sliced
=
logprobsLists
.
slice
(
1
,
3
)
sliced
=
logprobsLists
.
slice
_request
(
1
,
num_positions
=
2
)
assert
sliced
.
logprob_token_ids
==
[[
2
],
[
3
]]
assert
sliced
.
logprob_token_ids
==
[[
2
],
[
3
]]
assert
sliced
.
logprobs
==
[[
0.2
],
[
0.3
]]
assert
sliced
.
logprobs
==
[[
0.2
],
[
0.3
]]
assert
sliced
.
sampled_token_ranks
==
[
2
,
3
]
assert
sliced
.
sampled_token_ranks
==
[
2
,
3
]
...
@@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):
...
@@ -51,7 +51,7 @@ class TestLogprobsLists(TestCase):
def
test_slice_from_start
(
self
):
def
test_slice_from_start
(
self
):
"""Test slicing from the start position"""
"""Test slicing from the start position"""
sliced
=
self
.
logprobsLists
.
slice
(
0
,
2
)
sliced
=
self
.
logprobsLists
.
slice
_request
(
0
,
num_positions
=
5
)
assert
len
(
sliced
.
logprob_token_ids
)
==
5
assert
len
(
sliced
.
logprob_token_ids
)
==
5
assert
sliced
.
logprob_token_ids
==
[
assert
sliced
.
logprob_token_ids
==
[
[
1
,
2
],
[
1
,
2
],
...
@@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
...
@@ -60,11 +60,11 @@ class TestLogprobsLists(TestCase):
[
7
,
8
],
[
7
,
8
],
[
9
,
10
],
[
9
,
10
],
]
]
assert
sliced
.
cu_num_generated_tokens
==
[
0
,
2
,
5
]
assert
sliced
.
cu_num_generated_tokens
is
None
def
test_slice_from_middle
(
self
):
def
test_slice_from_middle
(
self
):
"""Test slicing from the middle position"""
"""Test slicing from the middle position"""
sliced
=
self
.
logprobsLists
.
slice
(
1
,
3
)
sliced
=
self
.
logprobsLists
.
slice
_request
(
1
,
num_positions
=
7
)
assert
len
(
sliced
.
logprob_token_ids
)
==
7
assert
len
(
sliced
.
logprob_token_ids
)
==
7
assert
sliced
.
logprob_token_ids
==
[
assert
sliced
.
logprob_token_ids
==
[
[
5
,
6
],
[
5
,
6
],
...
@@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
...
@@ -75,27 +75,25 @@ class TestLogprobsLists(TestCase):
[
15
,
16
],
[
15
,
16
],
[
17
,
18
],
[
17
,
18
],
]
]
assert
sliced
.
cu_num_generated_tokens
==
[
0
,
3
,
7
]
assert
sliced
.
cu_num_generated_tokens
is
None
def
test_slice_single_request
(
self
):
def
test_slice_single_request
(
self
):
"""Test slicing a single request"""
"""Test slicing a single request"""
sliced
=
self
.
logprobsLists
.
slice
(
1
,
2
)
sliced
=
self
.
logprobsLists
.
slice
_request
(
1
,
num_positions
=
3
)
assert
len
(
sliced
.
logprob_token_ids
)
==
3
assert
len
(
sliced
.
logprob_token_ids
)
==
3
assert
sliced
.
logprob_token_ids
==
[[
5
,
6
],
[
7
,
8
],
[
9
,
10
]]
assert
sliced
.
logprob_token_ids
==
[[
5
,
6
],
[
7
,
8
],
[
9
,
10
]]
assert
sliced
.
cu_num_generated_tokens
==
[
0
,
3
]
assert
sliced
.
cu_num_generated_tokens
is
None
def
test_slice_last_request
(
self
):
def
test_slice_last_request
(
self
):
"""Test slicing the last request"""
"""Test slicing the last request"""
sliced
=
self
.
logprobsLists
.
slice
(
2
,
3
)
sliced
=
self
.
logprobsLists
.
slice
_request
(
2
,
num_positions
=
4
)
assert
len
(
sliced
.
logprob_token_ids
)
==
4
assert
len
(
sliced
.
logprob_token_ids
)
==
4
assert
sliced
.
logprob_token_ids
==
[[
11
,
12
],
[
13
,
14
],
[
15
,
16
],
[
17
,
18
]]
assert
sliced
.
logprob_token_ids
==
[[
11
,
12
],
[
13
,
14
],
[
15
,
16
],
[
17
,
18
]]
assert
sliced
.
cu_num_generated_tokens
==
[
0
,
4
]
assert
sliced
.
cu_num_generated_tokens
is
None
def
test_slice_all_requests
(
self
):
def
test_slice_all_requests
(
self
):
"""Test slicing all requests (full slice)"""
"""Test slicing all requests (full slice)"""
sliced
=
self
.
logprobsLists
.
slice
(
0
,
3
)
sliced
=
self
.
logprobsLists
.
slice
_request
(
0
,
num_positions
=
9
)
assert
len
(
sliced
.
logprob_token_ids
)
==
9
# All tokens
assert
len
(
sliced
.
logprob_token_ids
)
==
9
# All tokens
assert
sliced
.
logprob_token_ids
==
self
.
logprobsLists
.
logprob_token_ids
assert
sliced
.
logprob_token_ids
==
self
.
logprobsLists
.
logprob_token_ids
assert
(
assert
sliced
.
cu_num_generated_tokens
is
None
sliced
.
cu_num_generated_tokens
==
self
.
logprobsLists
.
cu_num_generated_tokens
)
vllm/v1/core/sched/scheduler.py
View file @
8e7a8916
...
@@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
...
@@ -234,11 +234,15 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
# Make sure the input position does not exceed the max model len or
num_spec_placeholders
=
max
(
0
,
request
.
num_output_placeholders
-
1
)
# request's max_tokens.
# This is necessary when using spec decoding and/or async scheduling.
max_total_tokens
=
min
(
max_total_tokens
=
min
(
request
.
num_prompt_tokens
+
request
.
max_tokens
,
self
.
max_model_len
# Avoid scheduling tokens that we're sure won't be needed based on
# request.max_tokens. For this calculation we assume placeholder
# speculated output tokens are rejected.
request
.
num_prompt_tokens
+
request
.
max_tokens
+
num_spec_placeholders
,
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
self
.
max_model_len
,
)
)
num_new_tokens
=
min
(
num_new_tokens
=
min
(
num_new_tokens
,
max_total_tokens
-
1
-
request
.
num_computed_tokens
num_new_tokens
,
max_total_tokens
-
1
-
request
.
num_computed_tokens
...
@@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1089,7 +1093,7 @@ class Scheduler(SchedulerInterface):
and
request
.
sampling_params
.
logprobs
is
not
None
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
and
logprobs
):
):
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_logprobs
=
logprobs
.
slice
_request
(
req_index
,
len
(
new_token_ids
)
)
if
new_token_ids
and
self
.
structured_output_manager
.
should_advance
(
request
):
if
new_token_ids
and
self
.
structured_output_manager
.
should_advance
(
request
):
struct_output_request
=
request
.
structured_output_request
struct_output_request
=
request
.
structured_output_request
...
...
vllm/v1/outputs.py
View file @
8e7a8916
...
@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
...
@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
# different for each request.
# different for each request.
cu_num_generated_tokens
:
list
[
int
]
|
None
=
None
cu_num_generated_tokens
:
list
[
int
]
|
None
=
None
def
slice
(
self
,
start_req_idx
:
int
,
end_req_idx
:
int
):
def
slice_request
(
self
,
req_idx
:
int
,
num_positions
:
int
):
if
self
.
cu_num_generated_tokens
:
if
self
.
cu_num_generated_tokens
is
not
None
:
start
=
self
.
cu_num_generated_tokens
[
start_req_idx
]
req_idx
=
self
.
cu_num_generated_tokens
[
req_idx
]
end
=
self
.
cu_num_generated_tokens
[
end_req_idx
]
end_idx
=
req_idx
+
num_positions
# Recompute cumulative array starting from 0
cu_num_offset
=
self
.
cu_num_generated_tokens
[
start_req_idx
]
sliced_cu_num_generated_tokens
=
[
cu_num
-
cu_num_offset
for
cu_num
in
self
.
cu_num_generated_tokens
[
start_req_idx
:
end_req_idx
+
1
]
]
else
:
start
=
start_req_idx
end
=
end_req_idx
sliced_cu_num_generated_tokens
=
None
return
LogprobsLists
(
return
LogprobsLists
(
self
.
logprob_token_ids
[
start
:
end
],
self
.
logprob_token_ids
[
req_idx
:
end_idx
],
self
.
logprobs
[
start
:
end
],
self
.
logprobs
[
req_idx
:
end_idx
],
self
.
sampled_token_ranks
[
start
:
end
],
self
.
sampled_token_ranks
[
req_idx
:
end_idx
],
sliced_cu_num_generated_tokens
,
None
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment