Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a0fba5c
Unverified
Commit
3a0fba5c
authored
Apr 21, 2025
by
Woosuk Kwon
Committed by
GitHub
Apr 21, 2025
Browse files
[V1][Spec Decode] Handle draft tokens beyond max_model_len (#16087)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
299ebb62
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
137 additions
and
15 deletions
+137
-15
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+6
-1
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+57
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+19
-9
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+7
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+32
-3
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+9
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-1
No files found.
tests/v1/core/test_scheduler.py
View file @
3a0fba5c
...
...
@@ -30,6 +30,7 @@ def create_scheduler(
use_kv_connector
:
bool
=
False
,
num_blocks
:
int
=
10000
,
block_size
:
int
=
16
,
max_model_len
:
Optional
[
int
]
=
None
,
)
->
Scheduler
:
'''Create scheduler under test.
...
...
@@ -44,12 +45,15 @@ def create_scheduler(
Returns:
:class:`Scheduler` instance
'''
if
max_model_len
is
None
:
max_model_len
=
max_num_batched_tokens
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_model_len
=
max_
num_batched_tok
en
s
,
max_model_len
=
max_
model_l
en
,
long_prefill_token_threshold
=
long_prefill_token_threshold
,
disable_chunked_mm_input
=
disable_chunked_mm_input
,
enable_chunked_prefill
=
True
,
)
model_config
=
ModelConfig
(
model
=
model
,
...
...
@@ -296,6 +300,7 @@ def test_no_mm_input_chunking():
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_num_batched_tokens
=
1024
,
disable_chunked_mm_input
=
True
,
max_model_len
=
2048
,
)
mm_positions
=
[[
PlaceholderRange
(
offset
=
400
,
length
=
800
)]]
requests
=
create_requests
(
num_requests
=
1
,
...
...
tests/v1/spec_decode/test_max_len.py
0 → 100644
View file @
3a0fba5c
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import
pytest
from
vllm
import
LLM
,
SamplingParams
# Prompts passed to llm.generate() by both max-length tests in this file.
_PROMPTS
=
[
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
,
"Repeat the following sentence 10 times: Consistency is key to mastering any skill."
,
# noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail."
,
# noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test n-gram speculative decoding at the max model length.

    The engine is built with ``max_model_len=100`` while each request asks
    for 100 new tokens with ``ignore_eos=True``. The test makes no
    assertions on the generated text; it passes if generation completes
    without raising.
    """
    ngram_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        # Set the V1 flag before the engine is constructed.
        patched_env.setenv("VLLM_USE_V1", "1")
        engine = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            # Eager mode for faster initialization.
            enforce_eager=True,
            speculative_config=ngram_config,
        )
        engine.generate(
            _PROMPTS,
            SamplingParams(max_tokens=100, ignore_eos=True),
        )
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test EAGLE speculative decoding at the max model length.

    The engine is built with ``max_model_len=100`` while each request asks
    for 100 new tokens with ``ignore_eos=True``. The test makes no
    assertions on the generated text; it passes if generation completes
    without raising.
    """
    eagle_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        # Set the V1 flag before the engine is constructed.
        patched_env.setenv("VLLM_USE_V1", "1")
        engine = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            # Eager mode for faster initialization.
            enforce_eager=True,
            speculative_config=eagle_config,
            max_model_len=100,
        )
        engine.generate(
            _PROMPTS,
            SamplingParams(max_tokens=100, ignore_eos=True),
        )
tests/v1/spec_decode/test_ngram.py
View file @
3a0fba5c
...
...
@@ -2,7 +2,7 @@
import
numpy
as
np
from
vllm.config
import
SpeculativeConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
_find_subarray_kmp
,
_kmp_lps_array
)
...
...
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
def
test_ngram_proposer
():
def
ngram_proposer
(
min_n
:
int
,
max_n
:
int
,
k
:
int
)
->
NgramProposer
:
return
NgramProposer
(
vllm_config
=
VllmConfig
(
speculative_config
=
SpeculativeConfig
.
from_dict
(
{
"prompt_lookup_min"
:
min_n
,
"prompt_lookup_max"
:
max_n
,
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
})))
# Dummy model config. Just to set max_model_len.
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
max_model_len
=
100
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
dtype
=
"auto"
,
seed
=
None
,
trust_remote_code
=
False
)
return
NgramProposer
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
speculative_config
=
SpeculativeConfig
.
from_dict
({
"prompt_lookup_min"
:
min_n
,
"prompt_lookup_max"
:
max_n
,
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
})))
# No match.
result
=
ngram_proposer
(
...
...
vllm/v1/core/sched/scheduler.py
View file @
3a0fba5c
...
...
@@ -185,6 +185,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens
=
min
(
num_new_tokens
,
self
.
max_model_len
-
request
.
num_computed_tokens
)
assert
num_new_tokens
>
0
# Schedule encoder inputs.
if
request
.
has_encoder_inputs
:
(
encoder_inputs_to_schedule
,
num_new_tokens
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
3a0fba5c
...
...
@@ -12,6 +12,8 @@ from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
PADDING_SLOT_ID
=
-
1
class
EagleProposer
:
...
...
@@ -23,6 +25,7 @@ class EagleProposer:
self
.
vllm_config
=
vllm_config
self
.
num_speculative_tokens
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
...
...
@@ -112,22 +115,48 @@ class EagleProposer:
# Update the inputs.
input_ids
=
draft_token_ids_list
[
-
1
]
positions
+=
1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len
=
positions
>=
self
.
max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
attn_metadata
.
seq_lens
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
# Compute the slot mapping.
block_numbers
=
positions
//
self
.
block_size
block_numbers
=
clamped_
positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
positions
%
self
.
block_size
)
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Run the model.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
hidden_states
=
hidden_states
,
positions
=
positions
,
positions
=
clamped_
positions
,
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
draft_token_ids
,
probs
=
compute_probs_and_sample_next_token
(
...
...
vllm/v1/spec_decode/ngram_proposer.py
View file @
3a0fba5c
...
...
@@ -18,6 +18,9 @@ class NgramProposer:
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
# Maximum length of the model.
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
))
...
...
@@ -50,9 +53,14 @@ class NgramProposer:
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# Do not generate draft tokens beyond the max model length.
k
=
min
(
self
.
k
,
self
.
max_model_len
-
context_token_ids
.
shape
[
0
])
if
k
<=
0
:
return
None
# TODO(woosuk): Optimize this.
for
n
in
range
(
self
.
max_n
,
self
.
min_n
-
1
,
-
1
):
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
self
.
k
)
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
if
result
is
not
None
:
return
result
return
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3a0fba5c
...
...
@@ -1271,7 +1271,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
.
append
([])
continue
# Skip requests that require top-p, top-k, etc.
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
if
not
is_spec_decode_supported
(
req_id
,
self
.
input_batch
):
draft_token_ids
.
append
([])
...
...
@@ -1280,6 +1281,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add sampled_token_ids to token_ids_cpu.
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
end_idx
=
start_idx
+
num_sampled_ids
if
end_idx
>=
self
.
max_model_len
:
# Skip requests that have already reached the max model length.
draft_token_ids
.
append
([])
continue
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment