Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d43afcc
Unverified
Commit
9d43afcc
authored
Nov 07, 2024
by
Nicolò Lucchesi
Committed by
GitHub
Nov 07, 2024
Browse files
[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
ae62fd17
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
477 additions
and
147 deletions
+477
-147
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+0
-34
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+102
-3
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+35
-1
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+7
-2
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+27
-4
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+82
-0
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+60
-11
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+7
-3
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+7
-0
vllm/config.py
vllm/config.py
+6
-8
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+5
-3
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+31
-30
vllm/spec_decode/mqa_scorer.py
vllm/spec_decode/mqa_scorer.py
+19
-12
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+78
-32
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+4
-4
No files found.
tests/spec_decode/e2e/test_compatibility.py
View file @
9d43afcc
...
...
@@ -5,40 +5,6 @@ from vllm import SamplingParams
from
.conftest
import
get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"enable_chunked_prefill"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_chunked_prefill
(
test_llm_generator
):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"Speculative decoding and chunked prefill"
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"speculative_model"
:
"JackFram/llama-68m"
,
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
9d43afcc
...
...
@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
{
# Verify the detokenizer assertions in the test work when spec
...
...
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"max_output_len"
,
[
...
...
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
...
...
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
...
...
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
...
...
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
"speculative_max_model_len"
:
32
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
...
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"speculative_disable_by_batch_size"
:
2
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"speculative_disable_by_batch_size"
:
2
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
...
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"enable_chunked_prefill"
:
False
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
])
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
}
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
...
...
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"enable_chunked_prefill"
:
False
}
# Try a range of common k.
for
k
in
[
1
,
2
,
3
]
])
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}
for
k
in
[
1
,
2
,
3
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
9d43afcc
...
...
@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
if
prefill_chunk_size
>
0
:
common_llm_kwargs
.
update
(
**
{
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
prefill_chunk_size
,
"max_num_seqs"
:
prefill_chunk_size
})
else
:
common_llm_kwargs
[
"enable_chunked_prefill"
]
=
False
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
...
...
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"enable_chunked_prefill"
:
True
,
"speculative_disable_mqa_scorer"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
,
"enable_chunked_prefill"
:
True
,
"speculative_disable_mqa_scorer"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
...
...
tests/spec_decode/test_ngram_worker.py
View file @
9d43afcc
...
...
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
...
...
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def
test_ngram_algo_correctness_for_batches_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
For the scenario find candidate in all batches
"""
block_size
=
32
...
...
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size
,
final_prompt_lens
=
final_prompt_lens
)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
...
...
tests/spec_decode/test_scorer.py
View file @
9d43afcc
...
...
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@
pytest
.
mark
.
parametrize
(
'max_propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'mixed_propose_len'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
@
pytest
.
mark
.
parametrize
(
'prefill_chunking'
,
[
False
,
True
])
def
test_scorer
(
model_name
:
str
,
batch_size
:
int
,
max_propose_len
:
int
,
mixed_propose_len
:
bool
,
device
:
str
)
->
None
:
mixed_propose_len
:
bool
,
device
:
str
,
prefill_chunking
:
bool
)
->
None
:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length.
propose length, as well as mixed prefill-decode batches.
"""
seed
=
0
block_size
=
32
...
...
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if
not
mixed_propose_len
:
propose_lens
=
[
max_propose_len
]
*
batch_size
else
:
non_zero_cnt
=
random
.
randint
(
0
,
batch_size
)
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt
=
random
.
randint
(
1
,
batch_size
)
propose_lens
=
[
max_propose_len
]
*
non_zero_cnt
+
[
0
]
*
(
batch_size
-
non_zero_cnt
)
random
.
shuffle
(
propose_lens
)
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
max_propose_len
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
if
mixed_propose_len
and
prefill_chunking
and
(
n_prefills
:
=
batch_size
-
non_zero_cnt
):
prefill
,
_
,
_
=
create_batch
(
n_prefills
,
None
,
prefill_chunk_size
=
4
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
seq_ids
=
list
(
range
(
batch_size
,
batch_size
+
n_prefills
)))
# re-order to guarantee prefill|decode order
target_group_metadatalist
=
[
seq_group_metadatalist
[
i
]
for
i
,
p
in
enumerate
(
propose_lens
)
if
p
>
0
]
seq_group_metadatalist
=
prefill
+
target_group_metadatalist
propose_lens
=
[
0
]
*
n_prefills
+
[
p
for
p
in
propose_lens
if
p
>
0
]
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
num_lookahead_slots
=
max_propose_len
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
9d43afcc
...
...
@@ -10,6 +10,7 @@ import torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceOutput
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
...
...
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert
worker
.
_seq_with_bonus_token_in_last_step
==
\
{
4
,
5
,
10
}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
                         ["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int,
                              batch_composition: str):
    """
    Check which worker calls SpecDecodeWorker makes for prefill-only,
    decode-only and mixed prefill/decode batches.

    The scorer is mocked to raise, so any batch that contains decodes is
    expected to stop with that error after exactly one proposal call.
    """
    vocab_size = 32_000
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)

    # Recognizable message raised by the mocked scorer so the test can tell
    # that execution reached the scoring step.
    sentinel_msg = 'artificial stop'
    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
    worker.scorer.score_proposals.side_effect = ValueError(sentinel_msg)

    # Build a pool of decode requests and a pool of prefill chunks; the two
    # pools use disjoint seq_ids.
    decode_pool, _, _ = create_batch(batch_size, k)
    # Prompts are pre-chunked here; 'batch_size' chunks are returned.
    prefill_pool, _, _ = create_batch(batch_size,
                                      k,
                                      prefill_chunk_size=4,
                                      seq_ids=list(
                                          range(batch_size, batch_size * 2)))

    # Only the mixed case draws a random split between prefills and decodes.
    fixed_prefill_counts = {"prefill_only": batch_size, "decode_only": 0}
    if batch_composition in fixed_prefill_counts:
        n_prefills = fixed_prefill_counts[batch_composition]
    else:
        n_prefills = random.randint(1, batch_size - 1)
    n_decodes = batch_size - n_prefills

    prefill_chunks = random.sample(prefill_pool, n_prefills)
    decode_groups = random.sample(decode_pool, n_decodes)
    # Prefills are placed ahead of decodes in the submitted batch.
    batch = prefill_chunks + decode_groups
    execute_model_req = ExecuteModelRequest(seq_group_metadata_list=batch,
                                            num_lookahead_slots=k)

    # Random target-model outputs sized for batch_size * (k + 1) positions.
    num_positions = batch_size * (k + 1)
    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, num_positions),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    num_positions,
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       num_positions,
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)
    target_worker.execute_model.return_value = [target_output[0]]

    if not decode_groups:
        # Prefill-only batch: no speculation, both workers get the request
        # exactly once.
        worker.execute_model(execute_model_req=execute_model_req)
        draft_worker.execute_model.assert_called_once_with(execute_model_req)
        target_worker.execute_model.assert_called_once_with(execute_model_req)
        return

    # Decode-only or mixed batch: the mocked scorer raises ...
    with pytest.raises(ValueError, match=sentinel_msg):
        worker.execute_model(execute_model_req=execute_model_req)
    # ... but the draft proposal step has already run once by then.
    assert draft_worker.get_spec_proposals.call_count == 1
tests/spec_decode/utils.py
View file @
9d43afcc
...
...
@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
return
seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
        prompt: List[int],
        num_gpu_blocks: int,
        chunk_size: int,
        block_size: int,
        seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
    """Split one prompt into prefill chunks, one metadata entry per chunk.

    Args:
        prompt: Token ids of the full prompt.
        num_gpu_blocks: Size of the pool of block ids to allocate from.
        chunk_size: Maximum number of prompt tokens per chunk.
        block_size: Passed to ``round_up_to_next_block`` to decide how many
            blocks to allocate for the prompt.
        seq_id: Used (stringified) as the request id; ``None`` means 0.

    Returns:
        One ``SequenceGroupMetadata`` per chunk, in prompt order. An empty
        prompt yields an empty list.
    """
    request_id = str(0 if seq_id is None else seq_id)
    prompt_len = len(prompt)

    # Take block ids from the end of the pool; the same allocation list is
    # handed to every chunk of this prompt.
    available_blocks = list(range(num_gpu_blocks))
    num_blocks_needed = round_up_to_next_block(prompt_len, block_size)
    block_allocations = [
        available_blocks.pop() for _ in range(num_blocks_needed)
    ]

    chunk_metadata = []
    for chunk_idx, start in enumerate(range(0, prompt_len, chunk_size)):
        end = start + chunk_size
        # Each chunk carries the full prompt, with the tokens before this
        # chunk's start marked as already computed.
        data = SequenceData.from_seqs(prompt)
        data.update_num_computed_tokens(start)
        chunk_metadata.append(
            SequenceGroupMetadata(
                request_id=request_id,
                is_prompt=True,
                # Only the terminal chunk (the one reaching the end of the
                # prompt) samples a token.
                do_sample=end >= prompt_len,
                seq_data={chunk_idx: data},
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={chunk_idx: block_allocations},
                # The last chunk may be shorter than chunk_size.
                token_chunk_size=len(prompt[start:end])))

    return chunk_metadata
def
assert_logprobs_dict_allclose
(
actual_logprobs
:
List
[
Dict
[
int
,
Logprob
]],
expected_logprobs
:
List
[
Dict
[
int
,
Logprob
]])
->
None
:
...
...
@@ -198,7 +233,8 @@ def create_batch(batch_size,
prev_output_token_len
:
int
=
10
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
):
block_size
:
Optional
[
int
]
=
None
,
prefill_chunk_size
:
Optional
[
int
]
=
None
):
if
block_size
is
None
:
block_size
=
8
...
...
@@ -213,15 +249,28 @@ def create_batch(batch_size,
prompt_lens
=
prompt_len
prompts
=
[[
next
(
iterator
)
for
_
in
range
(
p_len
)]
for
p_len
in
prompt_lens
]
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_prompt_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
if
prefill_chunk_size
:
# Create a batch of chunked prompts.
if
not
seq_ids
:
seq_ids
=
list
(
range
(
len
(
prompts
)))
seq_group_metadata_list
=
[]
for
p
,
sid
in
zip
(
prompts
,
seq_ids
):
seq_group_metadata_list
+=
\
create_chunked_seq_group_metadata_from_prompt
(
p
,
num_gpu_blocks
,
prefill_chunk_size
,
block_size
,
sid
)
seq_group_metadata_list
=
seq_group_metadata_list
[:
batch_size
]
prev_output_tokens
=
[]
else
:
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_prompt_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
vllm/attention/backends/flash_attn.py
View file @
9d43afcc
...
...
@@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
self
.
query_start_loc
[
self
.
num_prefills
:]
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
...
...
@@ -903,7 +907,9 @@ def unified_flash_attention(
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
# use only for actual varlen decoding
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
...
...
@@ -949,8 +955,6 @@ def unified_flash_attention(
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert
decode_meta
is
not
None
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9d43afcc
...
...
@@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
...
...
vllm/attention/backends/xformers.py
View file @
9d43afcc
...
...
@@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
...
...
vllm/config.py
View file @
9d43afcc
...
...
@@ -192,7 +192,6 @@ class ModelConfig:
self
.
max_logprobs
=
max_logprobs
self
.
disable_sliding_window
=
disable_sliding_window
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
rope_scaling
,
rope_theta
,
config_format
)
...
...
@@ -1317,13 +1316,6 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if
enable_chunked_prefill
:
raise
ValueError
(
"Speculative decoding and chunked prefill are "
f
"currently mutually exclusive (
{
enable_chunked_prefill
=
}
)."
)
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision
=
None
...
...
@@ -1390,6 +1382,12 @@ class SpeculativeConfig:
f
"num_speculative_tokens=
{
n_predict
}
, but "
f
"
{
num_speculative_tokens
=
}
was provided."
)
if
enable_chunked_prefill
and
draft_hf_config
.
model_type
in
(
"medusa"
,
"mlp_speculator"
,
"eagle"
):
raise
ValueError
(
"Chunked prefill and hidden-state based draft models are "
"not compatible."
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
...
...
vllm/core/scheduler.py
View file @
9d43afcc
...
...
@@ -1147,6 +1147,7 @@ class Scheduler:
# Update swapped requests.
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
# Put prefills first due to Attention backend ordering assumption.
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
running_scheduled
.
prefill_seq_groups
+
...
...
vllm/engine/output_processor/multi_step.py
View file @
9d43afcc
...
...
@@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
# When both spec-decode and pre-fill chunking are enabled, we
# don't have guaranteed samples here (e.g. all -1s).
if
valid_samples
:
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
...
...
vllm/spec_decode/batch_expansion.py
View file @
9d43afcc
...
...
@@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
# Batch has a mix of spec decode enabled and disabled seq groups
contracted
=
self
.
_contract_batch
(
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
)
,
execute_model_req
.
seq_group_metadata_list
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
...
...
@@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
spec_expanded_seqs
=
self
.
_create_scoring_model_input
(
seq_group_metadata_list
=
spec_seqs
,
proposal_token_ids
=
proposal_token_ids_list
,
# NOTE: We determine the seq ids in the expanded batch using the
...
...
@@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)),
)
num_scoring_tokens
=
len
(
target_seq_group_metadata_list
)
target_seq_group_metadata_list
.
extend
(
non_spec_seqs
)
num_scoring_tokens
=
len
(
spec_expanded_seqs
)
# Batch speculative and non-speculative (e.g. chunked prefill) requests
# but make sure order is prefill|decode due to backend requirement.
target_seq_group_metadata_list
=
non_spec_seqs
+
spec_expanded_seqs
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
self
,
contracted_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
...
...
@@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
contracted_bs
=
len
(
contracted_seq_group_metadata_list
)
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
...
...
@@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences
.
non_spec_expanded_bs
=
len
(
non_spec_
target_token_id
s
)
# non-speculative sequences, prefill chunks with no out tokens included
non_spec_expanded_bs
=
len
(
non_spec_
indice
s
)
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
...
...
@@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
all_hidden_states
=
None
if
non_spec_indices
:
# Rule out prefills that produce no tokens.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
contracted_seq_group_metadata_list
[
idx
].
do_sample
]
if
len
(
non_spec_indices
):
all_tokens
[
non_spec_indices
,
:
1
]
=
\
non_spec_target_token_ids
.
unsqueeze
(
1
)
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
...
...
@@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This function creates K+1 target SequenceGroupMetadata to take
advantage of the bonus token.
"""
assert
not
input_seq_group_metadata
.
is_prompt
,
(
"Speculating on "
"prompts not yet supported"
)
assert
len
(
input_seq_group_metadata
.
seq_data
)
==
1
,
(
"Beam search "
"not supported in speculative decoding"
)
...
...
@@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes
=
(
num_scoring_tokens
,
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
)
(
spec_probs
,
non_spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
# First samples are non-speculative, latter samples are from speculative
# scoring (prefill|decode order).
split_sizes
=
(
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
,
num_scoring_tokens
)
(
non_spec_probs
,
spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
non_spec_sampled_tokens
,
spec_sampled_tokens
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
(
spec_logprobs
,
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
(
non_spec_logprobs
,
spec_logprobs
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
if
sampler_output
.
hidden_states
is
not
None
:
(
spec_hidden_states
,
non_spec_hidden_states
,
)
=
sampler_output
.
hidden_states
.
split
(
split_sizes
)
(
non_spec_hidden_states
,
spec_hidden_states
)
=
sampler_output
.
hidden_states
.
split
(
split_sizes
)
else
:
spec_hidden_states
,
non_
spec_hidden_states
=
None
,
None
non_
spec_hidden_states
,
spec_hidden_states
=
None
,
None
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
...
...
vllm/spec_decode/mqa_scorer.py
View file @
9d43afcc
...
...
@@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
all_proposal_lengths
=
proposals
.
proposal_lens
.
tolist
()
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
if
all_proposal_lengths
[
i
]
==
0
:
# Keep prompt seqs untouched (keep computed_tokens for chunks).
target_seq_group_metadata_list
.
append
(
seq_group_metadata
)
continue
seq_data_dict
=
seq_group_metadata
.
seq_data
assert
len
(
seq_data_dict
)
==
1
seq_id
=
next
(
iter
(
seq_data_dict
.
keys
()))
...
...
@@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
new_seq_data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
output_token_ids
)
-
1
)
# Ensure that the new sequence has at least one token
# because we only use mqa scorer in the decoding stage.
# Ensure that the new decode sequence has at least one token.
assert
len
(
output_token_ids
)
>=
1
new_seq_data_dict
=
{
target_seq_id
:
new_seq_data
}
...
...
@@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
lora_request
=
None
,
token_chunk_size
=
1
,
)
target_seq_group_metadata_list
.
append
(
new_seq_group_metadata
)
...
...
@@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
all_probs
=
target_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
else
:
# We either have decodes with different lens or prefill+decodes.
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
bs
,
k
+
1
),
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
...
...
@@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
fill_value
=-
float
(
"inf"
))
target_token_ids
=
target_token_ids
.
flatten
()
start_loc
=
0
for
i
,
proposed_len
in
enumerate
(
all_proposal_lengths
):
output_len
=
proposed_len
+
1
end_loc
=
start_loc
+
output_len
all_tokens
[
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
all_logprobs
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
start_loc
=
end_loc
for
i
,
(
proposed_len
,
seq_meta
)
in
enumerate
(
zip
(
all_proposal_lengths
,
target_seq_group_metadata_list
)):
# Skip chunks with no output tokens.
if
seq_meta
.
do_sample
:
output_len
=
proposed_len
+
1
end_loc
=
start_loc
+
output_len
all_tokens
[
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
all_logprobs
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
start_loc
=
end_loc
hidden_states
=
None
if
target_sampler_output
.
hidden_states
is
not
None
:
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
9d43afcc
...
...
@@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
no_spec
=
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
# We expect `num_speculative_tokens` to be None for prefills.
no_spec
=
all
(
sgm
.
is_prompt
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
or
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
sgm
.
num_speculative_tokens
==
0
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
...
...
@@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
_serialize_sampler_output_no_logprobs
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sampler_output
:
SamplerOutput
)
->
SamplerOutput
:
sampler_output
:
SamplerOutput
)
->
List
[
SamplerOutput
]
:
"""
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
...
...
@@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
any
(
seq_output_prompt_logprobs
)
else
\
sampler_output
.
sampled_token_ids
).
tolist
()
seq_data_entries
=
(
seq_data_entries
=
[
(
seq_id
,
seq_data
)
for
sg
in
\
execute_model_req
.
seq_group_metadata_list
\
for
seq_id
,
seq_data
in
sg
.
seq_data
.
items
()
)
if
sg
.
do_sample
# ignore empty token sequences
]
completion_seq_group_output_list
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
index
,
((
seq_id
,
seq_data
),
needs_prompt_logprobs
)
in
\
enumerate
(
zip
(
seq_data_entries
,
seq_output_prompt_logprobs
)):
if
needs_prompt_logprobs
:
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_logprobs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
output_index
=
0
# Make sure the non-terminal prefill chunks are still aligned with
# their own empty output.
for
seq_group_meta
in
execute_model_req
.
seq_group_metadata_list
:
# Since we can get chunks here, we don't always have a sampled token
# (only on last chunk) but we still have to provide an output.
if
not
seq_group_meta
.
do_sample
:
completion_seq_group_output_list
.
append
(
CompletionSequenceGroupOutput
(
samples
=
[],
prompt_logprobs
=
None
))
else
:
# Sequence with output.
seq_id
,
seq_data
=
seq_data_entries
[
output_index
]
needs_prompt_logprobs
=
seq_output_prompt_logprobs
[
output_index
]
if
needs_prompt_logprobs
:
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_logprobs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
# no prompt logprobs for the first token
for
p_token_id
in
prompt_token_ids
[
1
:]
]
else
:
prompt_logprobs
=
None
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
output_index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
# no prompt logprobs for the first token
for
p_token_id
in
prompt_token_ids
[
1
:]
]
else
:
prompt_logprobs
=
None
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
prompt_logprobs
=
prompt_logprobs
))
return
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)
prompt_logprobs
=
prompt_logprobs
))
output_index
+=
1
return
[
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)]
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
...
...
@@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states
=
sampler_output
.
hidden_states
if
hidden_states
is
not
None
:
# remove hidden_states for prompt tokens
# TODO Enable `return_hidden_states`: prefill chunks hidden states
# are pruned by the logits processor. Also, they should be arranged
# back into full-prefill latent. Address it to enable MLPSpeculator.
if
any
(
seq
.
is_prompt
for
seq
in
execute_model_req
.
seq_group_metadata_list
):
hidden_states
=
hidden_states
[
...
...
@@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
sampler_output
)
[
sampler_output
]
)
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output_to_return
]
return
sampler_output_to_return
def
_run_non_driver_rank
(
self
)
->
bool
:
"""Run proposer and verifier model in non-driver workers. This is used
...
...
@@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert
num_lookahead_slots
==
execute_model_req
.
num_lookahead_slots
# Pass last hidden states from target model to proposer
...
...
@@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
)
_
,
(
non_spec_seqs
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
execute_model_req
.
seq_group_metadata_list
,
proposals
.
proposal_lens
)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
execute_model_req
.
seq_group_metadata_list
[
idx
].
is_prompt
]
if
len
(
non_spec_indices
):
all_hidden_states
=
proposal_scores
.
hidden_states
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
if
all_hidden_states
is
not
None
:
prefill_hidden_states
=
all_hidden_states
[
non_spec_indices
]
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
prefill_hidden_states
)
# Sync proposer KV cache for prefills.
prefill_req
=
execute_model_req
.
clone
(
non_spec_seqs
)
self
.
proposer_worker
.
execute_model
(
prefill_req
)
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
...
...
@@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,
second_last_token_hidden_states
)
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
...
...
@@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# e.g. mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
...
...
@@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is periodic because the rejection sampler emits metrics
# periodically.
self
.
_maybe_log_stage_times
(
*
stage_times
)
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
9d43afcc
...
...
@@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
no_proposals
=
maybe_sampler_output
is
None
)
return
proposals
def
_split_by_proposal_len
(
...
...
@@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if
seq_group_metadata
.
num_speculative_tokens
==
0
:
# The speculative decoding for this request has either been disabled
# (e.g. due to high traffic) or this is a prompt request.
if
(
seq_group_metadata
.
is_prompt
or
seq_group_metadata
.
num_speculative_tokens
==
0
):
proposal_lens
.
append
(
0
)
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment