Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
62b8aebc
Unverified
Commit
62b8aebc
authored
Apr 23, 2024
by
Cade Daniel
Committed by
GitHub
Apr 23, 2024
Browse files
[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
parent
050f285f
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1111 additions
and
150 deletions
+1111
-150
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+6
-2
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+2
-1
tests/spec_decode/e2e/__init__.py
tests/spec_decode/e2e/__init__.py
+0
-0
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+35
-10
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+169
-0
tests/spec_decode/e2e/test_correctness.py
tests/spec_decode/e2e/test_correctness.py
+493
-47
tests/spec_decode/test_metrics.py
tests/spec_decode/test_metrics.py
+2
-2
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+2
-2
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+23
-17
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+5
-2
vllm/config.py
vllm/config.py
+65
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+15
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+30
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+22
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+1
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+7
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+167
-15
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+42
-28
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+2
-2
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+23
-8
No files found.
tests/samplers/test_rejection_sampler.py
View file @
62b8aebc
...
...
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids
,
)
# Bonus tokens are currently disabled. Verify they're set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids
=
bonus_token_ids
.
clone
()
*
0
-
1
if
which_tokens_accepted
==
"all_tokens_accepted"
:
# Expect all tokens to be equal to draft tokens.
assert
torch
.
equal
(
output_token_ids
[:,
:
-
1
],
draft_token_ids
)
# Expect all bonus tokens to be included.
assert
torch
.
equal
(
output_token_ids
[:,
-
1
:],
bonus_token_ids
)
assert
torch
.
equal
(
output_token_ids
[:,
-
1
:],
expected_
bonus_token_ids
)
elif
which_tokens_accepted
==
"no_tokens_accepted"
:
# Expect first token to be equal to recovered tokens.
assert
torch
.
equal
(
output_token_ids
[:,
0
],
recovered_token_ids
[:,
0
])
...
...
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch
.
ones_like
(
output_token_ids
[:,
1
:])
*
-
1
)
elif
which_tokens_accepted
==
"some_tokens_accepted"
:
recovered_plus_bonus
=
torch
.
cat
(
(
recovered_token_ids
,
bonus_token_ids
),
dim
=-
1
)
(
recovered_token_ids
,
expected_
bonus_token_ids
),
dim
=-
1
)
# Assert first rejected token is a recovered token or bonus token.
assert
torch
.
equal
(
recovered_plus_bonus
[
torch
.
arange
(
0
,
batch_size
),
...
...
tests/samplers/test_sampler.py
View file @
62b8aebc
...
...
@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def
mock_sample
(
probs
,
*
args
,
**
kwargs
):
nonlocal
sample_probs
sample_probs
=
probs
return
[[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
]
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
],
None
)
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
...
...
tests/spec_decode/e2e/__init__.py
0 → 100644
View file @
62b8aebc
tests/spec_decode/e2e/conftest.py
View file @
62b8aebc
from
typing
import
List
,
Tuple
import
pytest
from
tests.conftest
import
cleanup
...
...
@@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed
@
pytest
.
fixture
def
baseline_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
seed
):
return
create_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
def
baseline_llm_generator
(
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
seed
):
return
create_llm_generator
(
"baseline"
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
seed
)
@
pytest
.
fixture
def
test_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
def
test_llm_generator
(
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
seed
):
return
create_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
seed
)
return
create_llm_generator
(
"test"
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
seed
)
def
create_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
distinct_llm_kwargs
,
seed
):
def
create_llm_generator
(
baseline_or_test
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
distinct_llm_kwargs
,
seed
):
kwargs
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
distinct_llm_kwargs
,
}
test_name
=
request
.
node
.
name
def
generator_inner
():
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
LLM
(
**
kwargs
)
set_random_seed
(
seed
)
...
...
@@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
del
llm
cleanup
()
def
generator_outer
():
for
llm
in
generator_inner
():
yield
llm
del
llm
return
generator_outer
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
tokens
=
[]
token_ids
=
[]
for
llm
in
llm_generator
():
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
tokens
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
del
llm
return
tokens
,
token_ids
tests/spec_decode/e2e/test_compatibility.py
0 → 100644
View file @
62b8aebc
import
pytest
from
vllm
import
SamplingParams
from
.conftest
import
get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_ray
(
test_llm_generator
):
"""Verify that speculative decoding with Ray fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
AssertionError
,
match
=
"Speculative decoding not yet supported for "
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"enable_chunked_prefill"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_chunked_prefill
(
test_llm_generator
):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"Speculative decoding and chunked prefill"
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len"
:
128
,
"speculative_max_model_len"
:
129
,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len"
:
2048
+
1
,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len"
:
4096
+
1
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_spec_max_model_len
(
test_llm_generator
):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be larger than"
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail_block_manager_v1
(
test_llm_generator
):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"Speculative decoding requires usage of the V2"
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
tests/spec_decode/e2e/test_correctness.py
View file @
62b8aebc
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from
itertools
import
cycle
from
typing
import
List
,
Tuple
import
pytest
from
transformers
import
AutoTokenizer
from
vllm
import
SamplingParams
from
.conftest
import
get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
...
...
@@ -14,9 +45,6 @@ from vllm import SamplingParams
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip real loading for fast test.
"load_format"
:
"dummy"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -31,22 +59,15 @@ from vllm import SamplingParams
"num_speculative_tokens"
:
5
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
1
,
},
{
# No spec decode.
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_logical_flow
(
test_llm_generator
,
batch_size
:
int
):
def
test_spec_decode_e2e_with_detokenization
(
test_llm_generator
,
batch_size
:
int
):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
...
...
@@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
skip_special_tokens
=
True
,
spaces_between_special_tokens
=
False
,
)
batch_tokens
,
batch_token_ids
=
get_output_from_llm_generator
(
...
...
@@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
# Expect a generation for each prompt in the batch.
assert
len
(
batch_token_ids
)
==
len
(
prompts
)
# Expect each generation to have expected number of tokens (note
# ignore_eos=True).
assert
all
(
len
(
token_ids
)
==
output_len
for
token_ids
in
batch_token_ids
)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert
[
len
(
token_ids
)
for
token_ids
in
batch_token_ids
]
==
([
output_len
]
*
batch_size
)
# Expect detokenized string to match.
tok
=
AutoTokenizer
.
from_pretrained
(
"JackFram/llama-68m"
)
...
...
@@ -92,14 +112,111 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use long output len for the small model test.
1536
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_greedy_correctness_tiny_model_bs1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
# Skip real loading for fast test.
"load_format"
:
"dummy"
,
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -109,43 +226,372 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray"
:
True
,
"model"
:
"JackFram/llama-68m"
,
},
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"max_output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
max_output_len
:
int
):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
=
False
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use decently long output len for a high quality test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_greedy_correctness_real_model_bs1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_greedy_correctness_real_model_large_bs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_xfail
(
test_llm_generator
):
"""Verify that speculative decoding with Ray fails.
def
test_spec_decode_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# As of this writing, vLLM only compiles with these 3 block sizes by
# default.
{
"block_size"
:
8
,
},
{
"block_size"
:
16
,
},
{
"block_size"
:
32
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_different_block_size
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_many_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
max_tokens
=
max_
output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
with
pytest
.
raises
(
AssertionError
,
match
=
"Speculative decoding not yet supported for "
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
tokens
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
del
llm
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
return
tokens
,
token_ids
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
tests/spec_decode/test_metrics.py
View file @
62b8aebc
...
...
@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens
=
0
k
=
5
num_possible
_tokens
=
AsyncMetricsCollector
.
get_max_num_
accep
ted_tokens
(
max_num_emitted
_tokens
=
AsyncMetricsCollector
.
get_max_num_
emit
ted_tokens
(
num_draft_tokens
,
k
)
rej_sampler
=
MagicMock
()
...
...
@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert
(
metrics
.
draft_acceptance_rate
==
num_accepted_tokens
/
num_draft_tokens
)
assert
(
metrics
.
system_efficiency
==
num_emitted_tokens
/
num_possible
_tokens
)
max_num_emitted
_tokens
)
else
:
assert
math
.
isnan
(
metrics
.
draft_acceptance_rate
)
assert
math
.
isnan
(
metrics
.
system_efficiency
)
tests/spec_decode/test_multi_step_worker.py
View file @
62b8aebc
...
...
@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
0
,
k
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
0
,
k
])
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
batch_size
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
for
_
in
range
(
batch_size
)]
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
62b8aebc
import
random
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
import
pytest
...
...
@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
...
...
@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
...
...
@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
num_lookahead_slots
=
k
)
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
args
,
_
=
rejection_sampler
.
call_args_list
[
0
]
(
actual_proposal_scores
,
actual_bonus_token_ids
,
actual_proposal_probs
,
actual_proposal_token_ids
)
=
args
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
actual
=
SimpleNamespace
(
**
kwargs
)
assert
torch
.
equal
(
actual
_
bonus_token_ids
,
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
target_token_ids
.
reshape
(
batch_size
,
k
+
1
)[:,
-
1
:])
assert
torch
.
equal
(
actual
_proposal_score
s
,
actual
.
target_prob
s
,
target_token_probs
.
reshape
(
batch_size
,
k
+
1
,
-
1
)[:,
:
-
1
])
assert
torch
.
equal
(
actual
_proposal
_token_ids
,
proposal_token_ids
)
assert
torch
.
equal
(
actual
_proposal
_probs
,
proposal_probs
)
assert
torch
.
equal
(
actual
.
draft
_token_ids
,
proposal_token_ids
)
assert
torch
.
equal
(
actual
.
draft
_probs
,
proposal_probs
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
...
@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
...
...
@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
...
...
@@ -500,8 +506,8 @@ def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
rejection_sampler
=
MagicMock
(
spec
=
RejectionSampler
)
rejection_sampler
.
token_id_dtype
=
torch
.
int64
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
...
...
tests/spec_decode/utils.py
View file @
62b8aebc
...
...
@@ -63,11 +63,14 @@ def create_execute_model_data(
def
mock_worker
(
cls
=
None
,
vocab_size
:
int
=
30_000
,
max_model_len
:
int
=
2048
,
rank
:
int
=
0
)
->
MagicMock
:
rank
:
int
=
0
,
use_spec
:
bool
=
True
)
->
MagicMock
:
if
cls
is
None
:
cls
=
Worker
worker
=
MagicMock
(
spec
=
cls
)
spec
=
cls
if
use_spec
else
None
worker
=
MagicMock
(
spec
=
spec
)
worker
.
vocab_size
=
vocab_size
worker
.
max_model_len
=
max_model_len
worker
.
rank
=
rank
...
...
vllm/config.py
View file @
62b8aebc
...
...
@@ -655,6 +655,9 @@ class SpeculativeConfig:
target_dtype
:
str
,
speculative_model
:
Optional
[
str
],
num_speculative_tokens
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
...
...
@@ -672,6 +675,15 @@ class SpeculativeConfig:
model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since its not
yet compatible with spec decode.
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
...
@@ -690,12 +702,21 @@ class SpeculativeConfig:
assert
(
speculative_model
is
not
None
and
num_speculative_tokens
is
not
None
)
if
enable_chunked_prefill
:
raise
ValueError
(
"Speculative decoding and chunked prefill are "
f
"currently mutually exclusive (
{
enable_chunked_prefill
=
}
)."
)
if
not
use_v2_block_manager
:
raise
ValueError
(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager."
)
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision
=
None
draft_code_revision
=
None
draft_quantization
=
None
draft_max_model_len
=
None
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
...
...
@@ -707,7 +728,7 @@ class SpeculativeConfig:
revision
=
draft_revision
,
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
draft_max_model_len
,
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_context_len_to_capture
=
target_model_config
.
...
...
@@ -715,6 +736,13 @@ class SpeculativeConfig:
max_logprobs
=
target_model_config
.
max_logprobs
,
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
draft_model_config
.
max_model_len
,
target_model_config
.
max_model_len
,
))
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
))
...
...
@@ -725,6 +753,41 @@ class SpeculativeConfig:
num_speculative_tokens
,
)
@
staticmethod
def
_maybe_override_draft_max_model_len
(
speculative_max_model_len
:
Optional
[
int
],
draft_max_model_len
:
int
,
target_max_model_len
:
int
,
)
->
int
:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if
speculative_max_model_len
is
not
None
:
if
speculative_max_model_len
>
draft_max_model_len
:
raise
ValueError
(
f
"
{
speculative_max_model_len
=
}
cannot be "
f
"larger than
{
draft_max_model_len
=
}
"
)
if
speculative_max_model_len
>
target_max_model_len
:
raise
ValueError
(
f
"
{
speculative_max_model_len
=
}
cannot be "
f
"larger than
{
target_max_model_len
=
}
"
)
return
speculative_max_model_len
return
min
(
draft_max_model_len
,
target_max_model_len
,
)
@
staticmethod
def
create_draft_parallel_config
(
target_parallel_config
:
ParallelConfig
)
->
ParallelConfig
:
...
...
vllm/engine/arg_utils.py
View file @
62b8aebc
...
...
@@ -73,6 +73,7 @@ class EngineArgs:
# Speculative decoding configuration.
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -237,7 +238,7 @@ class EngineArgs:
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
EngineArgs
.
block_size
,
choices
=
[
8
,
16
,
32
,
128
],
choices
=
[
8
,
16
,
32
],
help
=
'Token block size for contiguous chunks of '
'tokens.'
)
...
...
@@ -420,17 +421,25 @@ class EngineArgs:
parser
.
add_argument
(
'--speculative-model'
,
type
=
str
,
default
=
None
,
default
=
EngineArgs
.
speculative_model
,
help
=
'The name of the draft model to be used in speculative decoding.'
)
parser
.
add_argument
(
'--num-speculative-tokens'
,
type
=
int
,
default
=
None
,
default
=
EngineArgs
.
num_speculative_tokens
,
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--speculative-max-model-len'
,
type
=
str
,
default
=
EngineArgs
.
speculative_max_model_len
,
help
=
'The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.'
)
parser
.
add_argument
(
'--model-loader-extra-config'
,
type
=
str
,
default
=
EngineArgs
.
model_loader_extra_config
,
...
...
@@ -481,6 +490,9 @@ class EngineArgs:
target_dtype
=
self
.
dtype
,
speculative_model
=
self
.
speculative_model
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
)
scheduler_config
=
SchedulerConfig
(
...
...
vllm/engine/llm_engine.py
View file @
62b8aebc
...
...
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroup
)
SequenceGroup
,
SequenceStage
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
...
...
@@ -480,9 +480,12 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if
seq_group
.
get_num_uncomputed_tokens
()
==
0
:
# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages
=
[
seq
.
data
.
_stage
for
seq
in
seq_group
.
seqs_dict
.
values
()]
if
all
(
stage
==
SequenceStage
.
DECODE
for
stage
in
stages
):
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
...
...
@@ -569,7 +572,8 @@ class LLMEngine:
# Log stats.
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
,
model_output
=
output
))
return
request_outputs
...
...
@@ -578,9 +582,18 @@ class LLMEngine:
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
=
None
))
def
_get_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
])
->
Stats
:
"""Get Stats to be Logged to Prometheus."""
def
_get_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
],
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
Stats
:
"""Get Stats to be Logged to Prometheus.
Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now
=
time
.
time
()
# KV Cache Usage in %.
...
...
@@ -637,6 +650,14 @@ class LLMEngine:
time_to_first_tokens
=
time_last_iters
if
prompt_run
else
[]
time_per_output_tokens
=
[]
if
prompt_run
else
time_last_iters
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if
model_output
and
(
model_output
[
0
].
spec_decode_worker_metrics
is
not
None
):
spec_decode_metrics
=
model_output
[
0
].
spec_decode_worker_metrics
else
:
spec_decode_metrics
=
None
return
Stats
(
now
=
now
,
num_running
=
num_running
,
...
...
@@ -649,6 +670,7 @@ class LLMEngine:
time_to_first_tokens
=
time_to_first_tokens
,
time_per_output_tokens
=
time_per_output_tokens
,
time_e2e_requests
=
time_e2e_requests
,
spec_decode_metrics
=
spec_decode_metrics
,
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
vllm/engine/metrics.py
View file @
62b8aebc
import
time
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Protocol
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Protocol
import
numpy
as
np
from
prometheus_client
import
(
REGISTRY
,
Counter
,
Gauge
,
Histogram
,
Info
,
...
...
@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
logger
=
init_logger
(
__name__
)
disable_created_metrics
()
...
...
@@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens
:
List
[
float
]
time_e2e_requests
:
List
[
float
]
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
class
SupportsMetricsInfo
(
Protocol
):
...
...
@@ -235,3 +240,19 @@ class StatLogger:
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
stats
.
spec_decode_metrics
))
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
return
(
"Speculative metrics: "
f
"Draft acceptance rate:
{
metrics
.
draft_acceptance_rate
:.
3
f
}
, "
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft tokens tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens tokens:
{
metrics
.
emitted_tokens
}
."
)
vllm/executor/gpu_executor.py
View file @
62b8aebc
...
...
@@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config
=
self
.
scheduler_config
,
device_config
=
self
.
device_config
,
cache_config
=
self
.
cache_config
,
# TODO allow draft-model specific load config.
load_config
=
self
.
load_config
,
local_rank
=
0
,
rank
=
0
,
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
62b8aebc
...
...
@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_probs
,
num_samples
=
1
).
reshape
(
batch_size
,
k
)
...
...
@@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
bonus_token_ids
,
-
1
)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens
[:,
-
1
]
=
-
1
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
recovered_token_ids
.
mul
(
after_false_mask
))
...
...
vllm/model_executor/layers/sampler.py
View file @
62b8aebc
...
...
@@ -35,6 +35,14 @@ class Sampler(nn.Module):
in logits for each token in the input prompt.
"""
def
__init__
(
self
):
super
().
__init__
()
# Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
def
forward
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -79,13 +87,45 @@ class Sampler(nn.Module):
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
sample_results
=
_sample
(
probs
,
logprobs
,
sampling_metadata
,
sampling_tensors
)
sample_results
,
maybe_sampled_tokens_tensor
=
_sample
(
probs
,
logprobs
,
sampling_metadata
,
sampling_tensors
,
include_gpu_probs_tensor
=
self
.
include_gpu_probs_tensor
,
modify_greedy_probs
=
self
.
_should_modify_greedy_probs_inplace
,
)
if
self
.
include_gpu_probs_tensor
:
assert
maybe_sampled_tokens_tensor
is
not
None
sampled_tokens_tensor
=
maybe_sampled_tokens_tensor
on_device_tensors
=
(
probs
,
sampled_tokens_tensor
)
else
:
on_device_tensors
=
None
# Get the logprobs query results.
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
logprobs
,
sampling_metadata
,
sample_results
)
return
_build_sampler_output
(
sample_results
,
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
)
return
_build_sampler_output
(
sample_results
,
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
)
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
"""Whether or not the sampler should modify the probability distribution
of greedily-sampled tokens such that multinomial sampling would sample
the greedily-sampled token.
In other words, if True then we set the probability of the greedily-
sampled token to 1.
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return
self
.
include_gpu_probs_tensor
def
_get_bin_counts_and_mask
(
...
...
@@ -359,7 +399,9 @@ def _sample_with_torch(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
categorized_seq_group_ids
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
...
...
@@ -371,6 +413,15 @@ def _sample_with_torch(
sample_metadata
=
{}
multinomial_samples
=
{}
# Create output tensor for sampled token ids.
if
include_gpu_probs_tensor
:
sampled_token_ids_tensor
=
torch
.
empty
(
logprobs
.
shape
[
0
],
1
,
dtype
=
torch
.
long
,
device
=
logprobs
.
device
)
else
:
sampled_token_ids_tensor
=
None
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for
sampling_type
in
SamplingType
:
...
...
@@ -383,9 +434,25 @@ def _sample_with_torch(
is_prompts
=
[
i
<
sampling_metadata
.
num_prompts
for
i
in
seq_group_ids
]
sample_metadata
[
sampling_type
]
=
(
seq_group_ids
,
seq_groups
,
is_prompts
,
sample_indices
)
long_sample_indices
=
sample_indices
.
long
()
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
sample_indices
.
long
()
],
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_
sample_indices
],
dim
=-
1
)
if
include_gpu_probs_tensor
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
long_sample_indices
]
=
greedy_samples
.
unsqueeze
(
-
1
)
if
modify_greedy_probs
:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace
(
logprobs
,
probs
,
long_sample_indices
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
max_best_of_in_batch
=
1
for
seq_group
,
is_prompt
in
zip
(
seq_groups
,
is_prompts
):
...
...
@@ -397,15 +464,23 @@ def _sample_with_torch(
"seq_groups"
:
seq_groups
,
"generators"
:
sampling_metadata
.
generators
,
}
multinomial_samples
[
sampling_type
]
=
_multinomial
(
probs
[
sample_indices
.
long
()
],
max_best_of_in_batch
,
probs
[
long_
sample_indices
],
max_best_of_in_batch
,
**
seeded_args
)
if
include_gpu_probs_tensor
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
...
...
@@ -427,7 +502,7 @@ def _sample_with_torch(
sample_results_dict
[
i
]
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
return
sample_results
,
sampled_token_ids_tensor
def
_sample_with_triton_kernel
(
...
...
@@ -511,12 +586,17 @@ def _sample_with_triton_kernel(
def
_sample
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
return
_sample_with_torch
(
probs
,
logprobs
,
sampling_metadata
)
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
)
->
Tuple
[
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
Optional
[
torch
.
Tensor
]]:
return
_sample_with_torch
(
probs
,
logprobs
,
sampling_metadata
,
include_gpu_probs_tensor
=
include_gpu_probs_tensor
,
modify_greedy_probs
=
modify_greedy_probs
,
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
...
...
@@ -680,12 +760,73 @@ def _get_logprobs(
return
result_prompt_logprobs
,
result_sample_logprobs
def
_modify_greedy_probs_inplace
(
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
sample_indices
:
torch
.
Tensor
,
greedy_samples
:
torch
.
Tensor
)
->
None
:
"""Modify the probability distributions of the greedily-sampled tokens such
that each sampled token has a "probability" of 1.0. This is required by
speculative decoding, which depends on the sampling method being encoded
within the probability distribution for correctness.
# Why do we only need to do this for greedy sampling?
vLLM's sampler performs the following steps for greedy or multinomial
(random) sampling:
1. Get logits from model.
2. Modify logits according to per-sequence sampling parameters.
- Multiply by temperature, top-k and top-p masking, penalize tokens
according to their frequency, etc.
3. Sample a token.
- Random sampling simply samples from the modified probability
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding
to how often we see it in our sampling. In other words, for tokens sampled
with vLLM's random SamplingType, the computed probability distribution
encodes the sampling methodology completely.
Greedy sampling does not normally have this property. vLLM modifies logits
according to sampling params, then performs `argmax`, then returns the
sampled token and the computed probability distribution. If we sample from
the distribution, we'll find the likelihood of the greedily-sampled token
is not always 1.0.
Since lossless speculative decoding requires that the sampling methodology
be encoded within the probability distribution, we are motivated to modify
the probability distribution such that the sampled token has probability 1
when speculative decoding is used.
NOTE: Alternatively, we could use an extremely low temperature to achieve
greedy sampling using multinomial computation and unite the codepaths. This
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs
[
sample_indices
,
:]
=
-
float
(
'inf'
)
logprobs
[
sample_indices
,
greedy_samples
]
=
0.0
probs
[
sample_indices
,
:]
=
0
probs
[
sample_indices
,
greedy_samples
]
=
1.0
def
_build_sampler_output
(
sample_results
:
List
[
Tuple
[
List
[
int
],
List
[
int
]]],
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output
=
[]
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
...
...
@@ -701,4 +842,15 @@ def _build_sampler_output(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
return
SamplerOutput
(
outputs
=
sampler_output
)
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
sampled_token_probs
,
sampled_token_ids
=
on_device_tensors
else
:
sampled_token_probs
,
sampled_token_ids
=
(
None
,
None
)
return
SamplerOutput
(
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
)
vllm/spec_decode/batch_expansion.py
View file @
62b8aebc
...
...
@@ -6,8 +6,8 @@ import torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
maybe_mock_device_tensors
,
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
sampler_output_to_torch
,
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
WorkerBase
...
...
@@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
proposal_token_ids_list
=
proposals
.
proposal_token_ids
.
tolist
()
# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips
=
[
proposals
for
proposals
in
proposal_token_ids_list
if
-
1
not
in
proposals
]
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list
,
proposal_token_ids_list
=
proposal_token_ids_list
_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
)
...
...
@@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
original
_bs
=
len
(
seq_group_metadata_list
),
contracted
_bs
=
len
(
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
...
...
@@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero
=
True
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
spec_seqs
,
proposal_token_ids_list
)
seq_group_metadata_list
=
spec_seqs
,
proposal_token_ids
=
proposal_token_ids_list
,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter
=
self
.
_create_target_seq_id_iterator
(
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)),
)
num_scoring_tokens
=
len
(
target_seq_group_metadata_list
)
target_seq_group_metadata_list
.
extend
(
non_spec_seqs
)
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
def
_contract_batch
(
self
,
original
_bs
:
int
,
def
_contract_batch
(
self
,
contracted
_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
],
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
...
...
@@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors
(
sampler_output
=
target_sampler_output
,
batch_size
=
len
(
non_spec_indices
)
+
num_scoring_tokens
,
vocab_size
=
self
.
_vocab_size
,
device
=
self
.
_device
,
)
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
batch_size
,
k
=
proposals
.
proposal_token_ids
.
shape
expanded_batch_size
,
k
=
proposals
.
proposal_token_ids
.
shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
squeeze
().
reshape
(
batch_size
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
batch_size
,
k
+
1
,
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
original
_bs
,
k
+
1
),
all_tokens
=
torch
.
full
(
size
=
(
contracted
_bs
,
k
+
1
),
fill_value
=-
1
,
device
=
self
.
_device
,
dtype
=
torch
.
long
)
all_probs
=
torch
.
zeros
(
original
_bs
,
all_probs
=
torch
.
zeros
(
contracted
_bs
,
k
+
1
,
self
.
_vocab_size
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
0
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
if
spec_indices
:
...
...
@@ -192,17 +204,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_token_ids
:
List
[
List
[
TokenId
]],
# shape: [batch_size, k]
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
)
->
List
[
SequenceGroupMetadata
]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
"""
if
not
seq_group_metadata_list
:
return
[]
target_seq_ids_iter
=
self
.
_create_target_seq_id_iterator
(
get_all_seq_ids
(
seq_group_metadata_list
))
target_seq_group_metadata
=
list
(
chain
.
from_iterable
(
self
.
_create_target_seq_group_metadata
(
...
...
vllm/spec_decode/interfaces.py
View file @
62b8aebc
...
...
@@ -24,9 +24,9 @@ class SpeculativeProposals:
def
__repr__
(
self
):
return
(
f
"SpeculativeProposals("
f
"proposal_token_ids=
{
self
.
proposal_token_ids
.
shape
}
, "
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
f
"proposal_probs=
{
self
.
proposal_probs
.
shape
}
, "
f
"proposal_lens=
{
self
.
proposal_lens
.
shape
}
)"
)
f
"proposal_lens=
{
self
.
proposal_lens
}
)"
)
@
dataclass
...
...
vllm/spec_decode/metrics.py
View file @
62b8aebc
...
...
@@ -147,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
draft_tokens
=
self
.
_aggregate_num_draft_tokens
num_possible_tokens
=
self
.
get_max_num_accepted_tokens
(
draft_tokens
,
k
)
max_num_emitted_tokens
=
self
.
get_max_num_emitted_tokens
(
draft_tokens
,
k
)
if
draft_tokens
>
0
:
draft_acceptance_rate
=
accepted_tokens
/
draft_tokens
else
:
draft_acceptance_rate
=
float
(
"nan"
)
if
num_possible
_tokens
>
0
:
system_efficiency
=
emitted_tokens
/
num_possible
_tokens
if
max_num_emitted
_tokens
>
0
:
system_efficiency
=
emitted_tokens
/
max_num_emitted
_tokens
else
:
system_efficiency
=
float
(
"nan"
)
...
...
@@ -169,8 +170,22 @@ class AsyncMetricsCollector:
)
@
staticmethod
def
get_max_num_accepted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
# Divide by k since batch size can be variable.
total_num_spec_seqs
=
draft_tokens
/
k
num_accepted_per_seq_if_all_accepted
=
k
+
1
return
int
(
total_num_spec_seqs
/
num_accepted_per_seq_if_all_accepted
)
def
get_max_num_emitted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert
draft_tokens
%
k
==
0
total_num_spec_seqs
=
draft_tokens
//
k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted
=
k
+
1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return
total_num_spec_seqs
*
num_emitted_per_seq_if_all_accepted
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment