Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d43afcc
Unverified
Commit
9d43afcc
authored
Nov 07, 2024
by
Nicolò Lucchesi
Committed by
GitHub
Nov 07, 2024
Browse files
[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
ae62fd17
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
477 additions
and
147 deletions
+477
-147
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+0
-34
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+102
-3
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+35
-1
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+7
-2
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+27
-4
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+82
-0
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+60
-11
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+7
-3
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+7
-0
vllm/config.py
vllm/config.py
+6
-8
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+5
-3
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+31
-30
vllm/spec_decode/mqa_scorer.py
vllm/spec_decode/mqa_scorer.py
+19
-12
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+78
-32
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+4
-4
No files found.
tests/spec_decode/e2e/test_compatibility.py
View file @
9d43afcc
...
@@ -5,40 +5,6 @@ from vllm import SamplingParams
...
@@ -5,40 +5,6 @@ from vllm import SamplingParams
from
.conftest
import
get_output_from_llm_generator
from
.conftest
import
get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"enable_chunked_prefill"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_chunked_prefill
(
test_llm_generator
):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"Speculative decoding and chunked prefill"
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
9d43afcc
...
@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
...
@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
{
{
# Verify the detokenizer assertions in the test work when spec
# Verify the detokenizer assertions in the test work when spec
...
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
...
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
...
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
...
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"max_output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"max_output_len"
,
[
...
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
...
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
...
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
...
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
...
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
...
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
...
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
...
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
...
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
"speculative_max_model_len"
:
32
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
"speculative_max_model_len"
:
32
,
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
...
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"speculative_disable_by_batch_size"
:
2
,
"speculative_disable_by_batch_size"
:
2
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"speculative_disable_by_batch_size"
:
2
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
...
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"num_speculative_tokens"
:
k
,
"enable_chunked_prefill"
:
False
,
}
}
# Try a range of common k, as well as large speculation.
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
])
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
}
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
...
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"num_speculative_tokens"
:
k
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"enable_chunked_prefill"
:
False
}
}
# Try a range of common k.
# Try a range of common k.
for
k
in
[
1
,
2
,
3
]
for
k
in
[
1
,
2
,
3
]
])
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}
for
k
in
[
1
,
2
,
3
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
9d43afcc
...
@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
...
@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"ngram_prompt_lookup_max"
:
3
,
},
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
256
,
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
"""Verify greedy equality on a tiny model with different batch size."""
if
prefill_chunk_size
>
0
:
common_llm_kwargs
.
update
(
**
{
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
prefill_chunk_size
,
"max_num_seqs"
:
prefill_chunk_size
})
else
:
common_llm_kwargs
[
"enable_chunked_prefill"
]
=
False
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
...
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model"
:
"[ngram]"
,
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"ngram_prompt_lookup_max"
:
3
,
"enable_chunked_prefill"
:
False
,
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"enable_chunked_prefill"
:
True
,
"speculative_disable_mqa_scorer"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
...
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"ngram_prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
"speculative_disable_by_batch_size"
:
4
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
,
"enable_chunked_prefill"
:
True
,
"speculative_disable_mqa_scorer"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}])
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
tests/spec_decode/test_ngram_worker.py
View file @
9d43afcc
...
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
...
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def
test_ngram_algo_correctness_for_batches_match_all
():
def
test_ngram_algo_correctness_for_batches_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
For the scenario find candidate in all batch
e
s
"""
"""
block_size
=
32
block_size
=
32
...
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
...
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
...
...
tests/spec_decode/test_scorer.py
View file @
9d43afcc
...
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
...
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@
pytest
.
mark
.
parametrize
(
'max_propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'max_propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'mixed_propose_len'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'mixed_propose_len'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
@
pytest
.
mark
.
parametrize
(
'prefill_chunking'
,
[
False
,
True
])
def
test_scorer
(
model_name
:
str
,
batch_size
:
int
,
max_propose_len
:
int
,
def
test_scorer
(
model_name
:
str
,
batch_size
:
int
,
max_propose_len
:
int
,
mixed_propose_len
:
bool
,
device
:
str
)
->
None
:
mixed_propose_len
:
bool
,
device
:
str
,
prefill_chunking
:
bool
)
->
None
:
"""
"""
Compare the batch expansion scorer and mqa scorer return the same score.
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
We test for both queries with the same propose length and different
propose length.
propose length
, as well as mixed prefill-decode batches
.
"""
"""
seed
=
0
seed
=
0
block_size
=
32
block_size
=
32
...
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
...
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if
not
mixed_propose_len
:
if
not
mixed_propose_len
:
propose_lens
=
[
max_propose_len
]
*
batch_size
propose_lens
=
[
max_propose_len
]
*
batch_size
else
:
else
:
non_zero_cnt
=
random
.
randint
(
0
,
batch_size
)
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt
=
random
.
randint
(
1
,
batch_size
)
propose_lens
=
[
max_propose_len
propose_lens
=
[
max_propose_len
]
*
non_zero_cnt
+
[
0
]
*
(
batch_size
-
non_zero_cnt
)
]
*
non_zero_cnt
+
[
0
]
*
(
batch_size
-
non_zero_cnt
)
random
.
shuffle
(
propose_lens
)
random
.
shuffle
(
propose_lens
)
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
max_propose_len
,
max_propose_len
,
block_size
=
block_size
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
num_gpu_blocks
=
num_gpu_blocks
)
if
mixed_propose_len
and
prefill_chunking
and
(
n_prefills
:
=
batch_size
-
non_zero_cnt
):
prefill
,
_
,
_
=
create_batch
(
n_prefills
,
None
,
prefill_chunk_size
=
4
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
seq_ids
=
list
(
range
(
batch_size
,
batch_size
+
n_prefills
)))
# re-order to guarantee prefill|decode order
target_group_metadatalist
=
[
seq_group_metadatalist
[
i
]
for
i
,
p
in
enumerate
(
propose_lens
)
if
p
>
0
]
seq_group_metadatalist
=
prefill
+
target_group_metadatalist
propose_lens
=
[
0
]
*
n_prefills
+
[
p
for
p
in
propose_lens
if
p
>
0
]
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
num_lookahead_slots
=
max_propose_len
)
num_lookahead_slots
=
max_propose_len
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
9d43afcc
...
@@ -10,6 +10,7 @@ import torch
...
@@ -10,6 +10,7 @@ import torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceOutput
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
SpecDecodeWorkerMetrics
)
...
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
...
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert
worker
.
_seq_with_bonus_token_in_last_step
==
\
assert
worker
.
_seq_with_bonus_token_in_last_step
==
\
{
4
,
5
,
10
}
{
4
,
5
,
10
}
@
pytest
.
mark
.
parametrize
(
'k'
,
[
3
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_composition"
,
[
"prefill_only"
,
"decode_only"
,
"mixed"
])
@
torch
.
inference_mode
()
def
test_chunked_prefill_flow
(
k
:
int
,
batch_size
:
int
,
batch_composition
:
str
):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
"rejection_sampler"
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
exception_secret
=
'artificial stop'
worker
.
scorer
=
mock_worker
(
BatchExpansionTop1Scorer
)
worker
.
scorer
.
score_proposals
.
side_effect
=
ValueError
(
exception_secret
)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes
,
_
,
_
=
create_batch
(
batch_size
,
k
)
# Pre-chunking here, get 'batch_size' chunks.
prefill
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prefill_chunk_size
=
4
,
seq_ids
=
list
(
range
(
batch_size
,
batch_size
*
2
)))
if
batch_composition
==
"prefill_only"
:
n_prefills
=
batch_size
elif
batch_composition
==
"decode_only"
:
n_prefills
=
0
else
:
n_prefills
=
random
.
randint
(
1
,
batch_size
-
1
)
n_decodes
=
batch_size
-
n_prefills
prefill
=
random
.
sample
(
prefill
,
n_prefills
)
decodes
=
random
.
sample
(
decodes
,
n_decodes
)
target_group_metadata_list
=
prefill
+
decodes
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
target_group_metadata_list
,
num_lookahead_slots
=
k
)
target_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
1
,
batch_size
*
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
target_token_probs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
if
not
len
(
decodes
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# no spec run (prefill only)
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
else
:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# but first draft still counted
assert
draft_worker
.
get_spec_proposals
.
call_count
==
1
tests/spec_decode/utils.py
View file @
9d43afcc
...
@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
...
@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
return
seq_grou_metadata_list
return
seq_grou_metadata_list
def
create_chunked_seq_group_metadata_from_prompt
(
prompt
:
List
[
int
],
num_gpu_blocks
:
int
,
chunk_size
:
int
,
block_size
:
int
,
seq_id
:
Optional
[
int
]
=
None
)
->
List
[
SequenceGroupMetadata
]:
if
seq_id
is
None
:
seq_id
=
0
free_gpu_blocks
=
list
(
range
(
num_gpu_blocks
))
block_allocations
=
[
free_gpu_blocks
.
pop
()
for
_
in
range
(
round_up_to_next_block
(
len
(
prompt
),
block_size
))
]
seq_group_metadata_list
=
[]
for
i
,
idx
in
enumerate
(
range
(
0
,
len
(
prompt
),
chunk_size
)):
chunk_ids
=
prompt
[
idx
:
idx
+
chunk_size
]
data
=
SequenceData
.
from_seqs
(
prompt
)
data
.
update_num_computed_tokens
(
idx
)
seq_data
=
{
i
:
data
}
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
str
(
seq_id
),
is_prompt
=
True
,
do_sample
=
idx
+
chunk_size
>=
len
(
prompt
),
# terminal chunk
seq_data
=
seq_data
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
),
block_tables
=
{
i
:
block_allocations
},
token_chunk_size
=
len
(
chunk_ids
)))
return
seq_group_metadata_list
def
assert_logprobs_dict_allclose
(
def
assert_logprobs_dict_allclose
(
actual_logprobs
:
List
[
Dict
[
int
,
Logprob
]],
actual_logprobs
:
List
[
Dict
[
int
,
Logprob
]],
expected_logprobs
:
List
[
Dict
[
int
,
Logprob
]])
->
None
:
expected_logprobs
:
List
[
Dict
[
int
,
Logprob
]])
->
None
:
...
@@ -198,7 +233,8 @@ def create_batch(batch_size,
...
@@ -198,7 +233,8 @@ def create_batch(batch_size,
prev_output_token_len
:
int
=
10
,
prev_output_token_len
:
int
=
10
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
):
block_size
:
Optional
[
int
]
=
None
,
prefill_chunk_size
:
Optional
[
int
]
=
None
):
if
block_size
is
None
:
if
block_size
is
None
:
block_size
=
8
block_size
=
8
...
@@ -213,15 +249,28 @@ def create_batch(batch_size,
...
@@ -213,15 +249,28 @@ def create_batch(batch_size,
prompt_lens
=
prompt_len
prompt_lens
=
prompt_len
prompts
=
[[
next
(
iterator
)
for
_
in
range
(
p_len
)]
for
p_len
in
prompt_lens
]
prompts
=
[[
next
(
iterator
)
for
_
in
range
(
p_len
)]
for
p_len
in
prompt_lens
]
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_prompt_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
if
prefill_chunk_size
:
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
# Create a batch of chunked prompts.
prev_output_tokens
,
seq_ids
)
if
not
seq_ids
:
seq_ids
=
list
(
range
(
len
(
prompts
)))
seq_group_metadata_list
=
[]
for
p
,
sid
in
zip
(
prompts
,
seq_ids
):
seq_group_metadata_list
+=
\
create_chunked_seq_group_metadata_from_prompt
(
p
,
num_gpu_blocks
,
prefill_chunk_size
,
block_size
,
sid
)
seq_group_metadata_list
=
seq_group_metadata_list
[:
batch_size
]
prev_output_tokens
=
[]
else
:
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_prompt_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
vllm/attention/backends/flash_attn.py
View file @
9d43afcc
...
@@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -276,7 +276,11 @@ class FlashAttentionMetadata(AttentionMetadata):
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
self
.
query_start_loc
[
self
.
num_prefills
:]
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
if
self
.
seq_start_loc
is
not
None
else
None
,
...
@@ -903,7 +907,9 @@ def unified_flash_attention(
...
@@ -903,7 +907,9 @@ def unified_flash_attention(
# Decoding run.
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
assert
decode_meta
.
max_decode_query_len
is
not
None
# use only for actual varlen decoding
if
decode_meta
.
max_decode_query_len
>
1
:
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
"Only decoder-only models support max_decode_query_len > 1"
)
...
@@ -949,8 +955,6 @@ def unified_flash_attention(
...
@@ -949,8 +955,6 @@ def unified_flash_attention(
assert
prefill_output
is
not
None
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
return
prefill_output
.
view
(
num_prefill_query_tokens
,
hidden_size
)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert
decode_meta
is
not
None
assert
decode_meta
is
not
None
decode_output
=
decode_output
.
squeeze
(
1
)
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9d43afcc
...
@@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -192,6 +192,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
)
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
def
advance_step
(
self
,
...
...
vllm/attention/backends/xformers.py
View file @
9d43afcc
...
@@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -272,6 +272,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
cross_block_tables
=
self
.
cross_block_tables
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
...
vllm/config.py
View file @
9d43afcc
...
@@ -192,7 +192,6 @@ class ModelConfig:
...
@@ -192,7 +192,6 @@ class ModelConfig:
self
.
max_logprobs
=
max_logprobs
self
.
max_logprobs
=
max_logprobs
self
.
disable_sliding_window
=
disable_sliding_window
self
.
disable_sliding_window
=
disable_sliding_window
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
rope_scaling
,
rope_theta
,
code_revision
,
rope_scaling
,
rope_theta
,
config_format
)
config_format
)
...
@@ -1317,13 +1316,6 @@ class SpeculativeConfig:
...
@@ -1317,13 +1316,6 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
f
"
{
speculative_disable_by_batch_size
=
}
"
)
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if
enable_chunked_prefill
:
raise
ValueError
(
"Speculative decoding and chunked prefill are "
f
"currently mutually exclusive (
{
enable_chunked_prefill
=
}
)."
)
# TODO: The user should be able to specify revision/max model len
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
# for the draft model. It is not currently supported.
draft_revision
=
None
draft_revision
=
None
...
@@ -1390,6 +1382,12 @@ class SpeculativeConfig:
...
@@ -1390,6 +1382,12 @@ class SpeculativeConfig:
f
"num_speculative_tokens=
{
n_predict
}
, but "
f
"num_speculative_tokens=
{
n_predict
}
, but "
f
"
{
num_speculative_tokens
=
}
was provided."
)
f
"
{
num_speculative_tokens
=
}
was provided."
)
if
enable_chunked_prefill
and
draft_hf_config
.
model_type
in
(
"medusa"
,
"mlp_speculator"
,
"eagle"
):
raise
ValueError
(
"Chunked prefill and hidden-state based draft models are "
"not compatible."
)
draft_model_config
.
max_model_len
=
(
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
speculative_max_model_len
,
...
...
vllm/core/scheduler.py
View file @
9d43afcc
...
@@ -1147,6 +1147,7 @@ class Scheduler:
...
@@ -1147,6 +1147,7 @@ class Scheduler:
# Update swapped requests.
# Update swapped requests.
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
# Put prefills first due to Attention backend ordering assumption.
return
SchedulerOutputs
(
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
running_scheduled
.
prefill_seq_groups
+
running_scheduled
.
prefill_seq_groups
+
...
...
vllm/engine/output_processor/multi_step.py
View file @
9d43afcc
...
@@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -134,10 +134,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sample
for
sample
in
samples
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
# When both spec-decode and pre-fill chunking are enabled, we
sequence_group
.
sampling_params
)
# don't have guaranteed samples here (e.g. all -1s).
if
valid_samples
:
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
...
...
vllm/spec_decode/batch_expansion.py
View file @
9d43afcc
...
@@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -90,7 +90,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
else
:
# Batch has a mix of spec decode enabled and disabled seq groups
# Batch has a mix of spec decode enabled and disabled seq groups
contracted
=
self
.
_contract_batch
(
contracted
=
self
.
_contract_batch
(
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
)
,
execute_model_req
.
seq_group_metadata_list
,
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
num_scoring_tokens
=
num_scoring_tokens
,
...
@@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -126,7 +126,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
split_batch_by_proposal_len
(
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
seq_group_metadata_list
,
proposal_lens_list
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
spec_expanded_seqs
=
self
.
_create_scoring_model_input
(
seq_group_metadata_list
=
spec_seqs
,
seq_group_metadata_list
=
spec_seqs
,
proposal_token_ids
=
proposal_token_ids_list
,
proposal_token_ids
=
proposal_token_ids_list
,
# NOTE: We determine the seq ids in the expanded batch using the
# NOTE: We determine the seq ids in the expanded batch using the
...
@@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -135,16 +135,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)),
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)),
)
)
num_scoring_tokens
=
len
(
target_seq_group_metadata_list
)
num_scoring_tokens
=
len
(
spec_expanded_seqs
)
target_seq_group_metadata_list
.
extend
(
non_spec_seqs
)
# Batch speculative and non-speculative (e.g. chunked prefill) requests
# but make sure order is prefill|decode due to backend requirement.
target_seq_group_metadata_list
=
non_spec_seqs
+
spec_expanded_seqs
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
num_scoring_tokens
)
def
_contract_batch
(
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
SamplerOutput
,
self
,
contracted_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
"""Contract the expanded batch back into its original size.
...
@@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -154,6 +157,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
target_sampler_output will be contracted to.
"""
"""
contracted_bs
=
len
(
contracted_seq_group_metadata_list
)
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_logprobs
,
...
@@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -166,8 +170,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences
.
# non-speculative sequences
, prefill chunks with no out tokens included
non_spec_expanded_bs
=
len
(
non_spec_
target_token_id
s
)
non_spec_expanded_bs
=
len
(
non_spec_
indice
s
)
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
...
@@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -191,7 +195,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
else
:
all_hidden_states
=
None
all_hidden_states
=
None
if
non_spec_indices
:
# Rule out prefills that produce no tokens.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
contracted_seq_group_metadata_list
[
idx
].
do_sample
]
if
len
(
non_spec_indices
):
all_tokens
[
non_spec_indices
,
:
1
]
=
\
all_tokens
[
non_spec_indices
,
:
1
]
=
\
non_spec_target_token_ids
.
unsqueeze
(
1
)
non_spec_target_token_ids
.
unsqueeze
(
1
)
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
...
@@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -290,9 +299,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This function creates K+1 target SequenceGroupMetadata to take
This function creates K+1 target SequenceGroupMetadata to take
advantage of the bonus token.
advantage of the bonus token.
"""
"""
assert
not
input_seq_group_metadata
.
is_prompt
,
(
"Speculating on "
"prompts not yet supported"
)
assert
len
(
input_seq_group_metadata
.
seq_data
)
==
1
,
(
assert
len
(
input_seq_group_metadata
.
seq_data
)
==
1
,
(
"Beam search "
"Beam search "
"not supported in speculative decoding"
)
"not supported in speculative decoding"
)
...
@@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -390,27 +396,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# and non spec sequences) and should be removed in the future. It can be
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
# done by supporting per-sequence proposal lens.
#
#
# First samples are from speculative scoring, latter samples are non-
# First samples are non-speculative, latter samples are from speculative
# speculative samples.
# scoring (prefill|decode order).
split_sizes
=
(
num_scoring_tokens
,
split_sizes
=
(
sampler_output
.
sampled_token_ids
.
numel
()
-
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
,
num_scoring_tokens
)
num_scoring_tokens
)
(
non_spec_probs
,
(
spec_probs
,
non_spec_probs
spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
non_spec_sampled_tokens
,
spec_sampled_tokens
(
spec_sampled_tokens
,
non_spec_sampled_tokens
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
(
(
non_spec_logprobs
,
spec_logprobs
,
spec_logprobs
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
if
sampler_output
.
hidden_states
is
not
None
:
if
sampler_output
.
hidden_states
is
not
None
:
(
(
non_spec_hidden_states
,
spec_hidden_states
spec_hidden_states
,
)
=
sampler_output
.
hidden_states
.
split
(
split_sizes
)
non_spec_hidden_states
,
)
=
sampler_output
.
hidden_states
.
split
(
split_sizes
)
else
:
else
:
spec_hidden_states
,
non_
spec_hidden_states
=
None
,
None
non_
spec_hidden_states
,
spec_hidden_states
=
None
,
None
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
...
...
vllm/spec_decode/mqa_scorer.py
View file @
9d43afcc
...
@@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
...
@@ -21,6 +21,11 @@ class MQAScorer(SpeculativeScorer):
all_proposal_lengths
=
proposals
.
proposal_lens
.
tolist
()
all_proposal_lengths
=
proposals
.
proposal_lens
.
tolist
()
for
i
,
seq_group_metadata
in
enumerate
(
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
execute_model_req
.
seq_group_metadata_list
):
if
all_proposal_lengths
[
i
]
==
0
:
# Keep prompt seqs untouched (keep computed_tokens for chunks).
target_seq_group_metadata_list
.
append
(
seq_group_metadata
)
continue
seq_data_dict
=
seq_group_metadata
.
seq_data
seq_data_dict
=
seq_group_metadata
.
seq_data
assert
len
(
seq_data_dict
)
==
1
assert
len
(
seq_data_dict
)
==
1
seq_id
=
next
(
iter
(
seq_data_dict
.
keys
()))
seq_id
=
next
(
iter
(
seq_data_dict
.
keys
()))
...
@@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
...
@@ -40,8 +45,7 @@ class MQAScorer(SpeculativeScorer):
new_seq_data
.
update_num_computed_tokens
(
new_seq_data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
output_token_ids
)
-
1
)
len
(
prompt_token_ids
)
+
len
(
output_token_ids
)
-
1
)
# Ensure that the new sequence has at least one token
# Ensure that the new decode sequence has at least one token.
# because we only use mqa scorer in the decoding stage.
assert
len
(
output_token_ids
)
>=
1
assert
len
(
output_token_ids
)
>=
1
new_seq_data_dict
=
{
target_seq_id
:
new_seq_data
}
new_seq_data_dict
=
{
target_seq_id
:
new_seq_data
}
...
@@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
...
@@ -54,7 +58,6 @@ class MQAScorer(SpeculativeScorer):
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
},
lora_request
=
None
,
lora_request
=
None
,
token_chunk_size
=
1
,
)
)
target_seq_group_metadata_list
.
append
(
new_seq_group_metadata
)
target_seq_group_metadata_list
.
append
(
new_seq_group_metadata
)
...
@@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
...
@@ -77,6 +80,7 @@ class MQAScorer(SpeculativeScorer):
all_probs
=
target_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_probs
=
target_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
else
:
else
:
# We either have decodes with different lens or prefill+decodes.
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
bs
,
k
+
1
),
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
bs
,
k
+
1
),
fill_value
=-
1
)
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
...
@@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
...
@@ -85,15 +89,18 @@ class MQAScorer(SpeculativeScorer):
fill_value
=-
float
(
"inf"
))
fill_value
=-
float
(
"inf"
))
target_token_ids
=
target_token_ids
.
flatten
()
target_token_ids
=
target_token_ids
.
flatten
()
start_loc
=
0
start_loc
=
0
for
i
,
proposed_len
in
enumerate
(
all_proposal_lengths
):
for
i
,
(
proposed_len
,
seq_meta
)
in
enumerate
(
output_len
=
proposed_len
+
1
zip
(
all_proposal_lengths
,
target_seq_group_metadata_list
)):
end_loc
=
start_loc
+
output_len
# Skip chunks with no output tokens.
all_tokens
[
if
seq_meta
.
do_sample
:
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
output_len
=
proposed_len
+
1
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
end_loc
=
start_loc
+
output_len
all_logprobs
[
all_tokens
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
start_loc
=
end_loc
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
all_logprobs
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
start_loc
=
end_loc
hidden_states
=
None
hidden_states
=
None
if
target_sampler_output
.
hidden_states
is
not
None
:
if
target_sampler_output
.
hidden_states
is
not
None
:
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
9d43afcc
...
@@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -418,7 +418,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# none of the requests in the batch have spec decoding enabled.
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# In any of these cases, the proposer and scorer workers
# are called normally.
# are called normally.
no_spec
=
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
# We expect `num_speculative_tokens` to be None for prefills.
no_spec
=
all
(
sgm
.
is_prompt
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
or
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
sgm
.
num_speculative_tokens
==
0
sgm
.
num_speculative_tokens
==
0
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
...
@@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -484,7 +487,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
_serialize_sampler_output_no_logprobs
(
def
_serialize_sampler_output_no_logprobs
(
self
,
execute_model_req
:
ExecuteModelRequest
,
self
,
execute_model_req
:
ExecuteModelRequest
,
sampler_output
:
SamplerOutput
)
->
SamplerOutput
:
sampler_output
:
SamplerOutput
)
->
List
[
SamplerOutput
]
:
"""
"""
Creates and returns a `SamplerOutput` with only the token IDs being
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
...
@@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -514,41 +517,56 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
any
(
seq_output_prompt_logprobs
)
else
\
if
any
(
seq_output_prompt_logprobs
)
else
\
sampler_output
.
sampled_token_ids
).
tolist
()
sampler_output
.
sampled_token_ids
).
tolist
()
seq_data_entries
=
(
seq_data_entries
=
[
(
seq_id
,
seq_data
)
for
sg
in
\
(
seq_id
,
seq_data
)
for
sg
in
\
execute_model_req
.
seq_group_metadata_list
\
execute_model_req
.
seq_group_metadata_list
\
for
seq_id
,
seq_data
in
sg
.
seq_data
.
items
()
for
seq_id
,
seq_data
in
sg
.
seq_data
.
items
()
)
if
sg
.
do_sample
# ignore empty token sequences
]
completion_seq_group_output_list
:
List
[
completion_seq_group_output_list
:
List
[
CompletionSequenceGroupOutput
]
=
[]
CompletionSequenceGroupOutput
]
=
[]
for
index
,
((
seq_id
,
seq_data
),
needs_prompt_logprobs
)
in
\
output_index
=
0
enumerate
(
zip
(
seq_data_entries
,
seq_output_prompt_logprobs
)):
# Make sure the non-terminal prefill chunks are still aligned with
if
needs_prompt_logprobs
:
# their own empty output.
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
for
seq_group_meta
in
execute_model_req
.
seq_group_metadata_list
:
prompt_logprobs
=
[
# Since we can get chunks here, we dont always have a sampled token
create_logprobs_output
(
# (only on last chunk) but we still have to provide an output.
token_id
=
p_token_id
,
if
not
seq_group_meta
.
do_sample
:
completion_seq_group_output_list
.
append
(
CompletionSequenceGroupOutput
(
samples
=
[],
prompt_logprobs
=
None
))
else
:
# Sequence with output.
seq_id
,
seq_data
=
seq_data_entries
[
output_index
]
needs_prompt_logprobs
=
seq_output_prompt_logprobs
[
output_index
]
if
needs_prompt_logprobs
:
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_logprobs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
# no prompt logprobs for the first token
for
p_token_id
in
prompt_token_ids
[
1
:]
]
else
:
prompt_logprobs
=
None
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
output_index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_token_ids
=
[],
topk_logprobs
=
[],
topk_logprobs
=
[],
)
prompt_logprobs
=
prompt_logprobs
))
# no prompt logprobs for the first token
output_index
+=
1
for
p_token_id
in
prompt_token_ids
[
1
:]
]
return
[
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)]
else
:
prompt_logprobs
=
None
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
prompt_logprobs
=
prompt_logprobs
))
return
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
...
@@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -568,6 +586,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states
=
sampler_output
.
hidden_states
hidden_states
=
sampler_output
.
hidden_states
if
hidden_states
is
not
None
:
if
hidden_states
is
not
None
:
# remove hidden_states for prompt tokens
# remove hidden_states for prompt tokens
# TODO Enable `return_hidden_states`: prefill chunks hidden states
# are pruned by the logits processor. Also, they should be arranged
# back into full-prefill latent. Address it to enable MLPSpeculator.
if
any
(
seq
.
is_prompt
if
any
(
seq
.
is_prompt
for
seq
in
execute_model_req
.
seq_group_metadata_list
):
for
seq
in
execute_model_req
.
seq_group_metadata_list
):
hidden_states
=
hidden_states
[
hidden_states
=
hidden_states
[
...
@@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -593,14 +614,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
if
self
.
_disable_logprobs
else
sampler_output
)
[
sampler_output
]
)
# Clear device tensors from sampler output. This reduces communication
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
# overhead when the engine runs in a different process than the workers.
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
logprobs
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output_to_return
]
return
sampler_output_to_return
def
_run_non_driver_rank
(
self
)
->
bool
:
def
_run_non_driver_rank
(
self
)
->
bool
:
"""Run proposer and verifier model in non-driver workers. This is used
"""Run proposer and verifier model in non-driver workers. This is used
...
@@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -644,9 +665,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
This invokes the proposer worker to get k speculative tokens for each
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
Returns a list of SamplerOutput, each containing a single token per
sequence.
sequence.
"""
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert
num_lookahead_slots
==
execute_model_req
.
num_lookahead_slots
assert
num_lookahead_slots
==
execute_model_req
.
num_lookahead_slots
# Pass last hidden states from target model to proposer
# Pass last hidden states from target model to proposer
...
@@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -671,6 +698,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
proposals
,
)
)
_
,
(
non_spec_seqs
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
execute_model_req
.
seq_group_metadata_list
,
proposals
.
proposal_lens
)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
execute_model_req
.
seq_group_metadata_list
[
idx
].
is_prompt
]
if
len
(
non_spec_indices
):
all_hidden_states
=
proposal_scores
.
hidden_states
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
if
all_hidden_states
is
not
None
:
prefill_hidden_states
=
all_hidden_states
[
non_spec_indices
]
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
prefill_hidden_states
)
# Sync proposer KV cache for prefills.
prefill_req
=
execute_model_req
.
clone
(
non_spec_seqs
)
self
.
proposer_worker
.
execute_model
(
prefill_req
)
with
Timer
()
as
verification_timer
:
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
...
@@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -769,7 +815,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
previous_hidden_states
=
HiddenStates
(
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,
hidden_states
,
seq_group_metadata_list
,
second_last_token_hidden_states
)
second_last_token_hidden_states
)
return
accepted_token_ids
,
logprobs
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
def
_create_output_sampler_list
(
...
@@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -819,6 +864,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
for
step_index
in
range
(
num_steps
):
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
if
all
(
token_id
==
-
1
...
@@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -861,7 +908,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is periodic because the rejection sampler emits metrics
# This is periodic because the rejection sampler emits metrics
# periodically.
# periodically.
self
.
_maybe_log_stage_times
(
*
stage_times
)
self
.
_maybe_log_stage_times
(
*
stage_times
)
return
sampler_output_list
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
9d43afcc
...
@@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -109,7 +109,6 @@ class Top1Proposer(SpeculativeProposer):
proposal_probs
=
proposal_probs
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
proposal_lens
=
proposal_lens
,
no_proposals
=
maybe_sampler_output
is
None
)
no_proposals
=
maybe_sampler_output
is
None
)
return
proposals
return
proposals
def
_split_by_proposal_len
(
def
_split_by_proposal_len
(
...
@@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -127,9 +126,10 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has been disabled
# The speculative decoding for this request has either been disabled
# (e.g. due to high traffic).
# (e.g. due to high traffic) or this is a prompt request.
if
seq_group_metadata
.
num_speculative_tokens
==
0
:
if
(
seq_group_metadata
.
is_prompt
or
seq_group_metadata
.
num_speculative_tokens
==
0
):
proposal_lens
.
append
(
0
)
proposal_lens
.
append
(
0
)
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment