Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd572c0a
Unverified
Commit
dd572c0a
authored
Jul 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 18, 2025
Browse files
[V0 Deprecation] Remove V0 Spec Decode workers (#21152)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
9ffe905a
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
5966 deletions
+0
-5966
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+0
-315
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+0
-417
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+0
-533
tests/spec_decode/e2e/test_mtp_correctness.py
tests/spec_decode/e2e/test_mtp_correctness.py
+0
-333
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+0
-842
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+0
-392
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+0
-70
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+0
-110
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+0
-90
tests/spec_decode/test_memory_usage.py
tests/spec_decode/test_memory_usage.py
+0
-91
tests/spec_decode/test_metrics.py
tests/spec_decode/test_metrics.py
+0
-205
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+0
-838
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+0
-221
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+0
-116
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+0
-945
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+0
-150
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+0
-290
tests/test_sequence.py
tests/test_sequence.py
+0
-1
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+0
-6
tools/mypy.sh
tools/mypy.sh
+0
-1
No files found.
tests/spec_decode/e2e/test_logprobs.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
from
..utils
import
maybe_enable_chunked_prefill
from
.conftest
import
run_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
,
12
])
def
test_logprobs_equality
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
,
prefill_chunk_size
:
int
):
"""Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
common_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
6
,
"disable_logprobs"
:
False
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_logprobs_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
# Artificially limit the draft model max model len; this forces
# vLLM to skip speculation once the sequences grow beyond 32-k
# tokens.
"max_model_len"
:
32
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
])
def
test_logprobs_when_skip_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
6
])
def
test_logprobs_temp_1
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
temperature
=
1.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
logprobs
=
logprobs
,
)
sd_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
test_llm_kwargs
,
}
with
vllm_runner
(
**
sd_args
)
as
vllm_model
:
sd_outputs
=
vllm_model
.
generate_w_logprobs
(
prompts
,
sampling_params
)
num_returned_logprobs
=
[
len
(
seq_logprobs
)
for
seq_logprobs
in
sd_outputs
[
-
1
]
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert
any
(
[
num_returned
>
logprobs
for
num_returned
in
num_returned_logprobs
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
0
])
def
test_logprobs_disabled
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
tests/spec_decode/e2e/test_medusa_correctness.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctness for the target model outputs.
"""
import
pytest
from
..utils
import
maybe_enable_chunked_prefill
from
.conftest
import
run_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
5
# precision
PRECISION
=
"float32"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
True
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"enforce_eager"
:
False
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_e2e_greedy_correctness_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
16
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
},
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_by_batch_size"
:
4
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_medusa_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_by_batch_size"
:
4
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
tests/spec_decode/e2e/test_mlp_correctness.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""
from
unittest.mock
import
patch
import
pytest
from
vllm.model_executor.layers.vocab_parallel_embedding
import
pad_vocab_size
from
..utils
import
maybe_enable_chunked_prefill
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-160m"
# speculative model
SPEC_MODEL
=
"ibm-ai-platform/llama-160m-accelerator"
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
3
# precision
PRECISION
=
"float32"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
32
])
def
test_mlp_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"disable_logprobs"
:
True
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
def
test_mlp_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
,
prefill_chunk_size
:
int
):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
# NOTE Test is sensitive enough st if we don't enable chunked prefill
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
baseline_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
def
test_mlp_e2e_acceptance_rate
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify acceptance rate with different batch size and large output
length."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seed
=
seed
,
expected_acceptance_rate
=
0.48
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# Speculative config
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"seed"
:
5
}])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
1.0
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_seeded_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
temperature
:
float
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
baseline_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seed
=
seed
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seed
=
seed
,
disable_seed
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
16
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
16
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
def
test_mlp_e2e_greedy_correctness_with_padding
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify greedy equality when the vocab dimension is padded
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
# Default pad_to is 64, test model has vocab_size of 32000
def
patched_pad_vocab_size
(
vocab_size
,
pad_to
=
None
):
return
pad_vocab_size
(
vocab_size
,
pad_to
=
32064
)
with
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size"
,
patched_pad_vocab_size
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
},
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"disable_by_batch_size"
:
4
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
MAIN_MODEL
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_mtp_correctness.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, mtp would not break the
correctness for the target model outputs.
"""
import
pytest
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"luccafong/deepseek_mtp_main_random"
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
1
# precision
PRECISION
=
"bfloat16"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
True
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_mtp_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"enforce_eager"
:
False
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
"gpu_memory_utilization"
:
0.85
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
},
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
# GPU memory utilization
"gpu_memory_utilization"
:
0.9
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_by_batch_size"
:
4
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mtp_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
tests/spec_decode/e2e/test_multistep_correctness.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from
itertools
import
cycle
import
pytest
from
transformers
import
AutoTokenizer
from
vllm
import
SamplingParams
from
...utils
import
create_new_process_for_each_test
from
.conftest
import
(
get_output_from_llm_generator
,
run_equality_correctness_test
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
{
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_with_detokenization
(
test_llm_generator
,
batch_size
:
int
):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len
=
32
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
batch_tokens
,
batch_token_ids
,
_
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
# Expect a generation for each prompt in the batch.
assert
len
(
batch_token_ids
)
==
len
(
prompts
)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert
[
len
(
token_ids
)
for
token_ids
in
batch_token_ids
]
==
([
output_len
]
*
batch_size
)
# Expect detokenized string to match.
tok
=
AutoTokenizer
.
from_pretrained
(
"JackFram/llama-68m"
)
for
actual_tokens
,
actual_token_ids
in
zip
(
batch_tokens
,
batch_token_ids
):
expected_tokens
=
tok
.
decode
(
actual_token_ids
)
print
(
f
"
{
actual_token_ids
=
}
"
)
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name"
:
"JackFram/llama-68m"
,
},
{
"model_name"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"disable_logprobs"
:
False
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
}])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use long output len for the small model test.
10
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_tiny_model_bs1
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
ensure_all_accepted
=
per_test_common_llm_kwargs
.
get
(
"model_name"
)
==
test_llm_kwargs
.
get
(
"speculative_config"
)[
"model"
]
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
prompt_logprobs
=
2
,
logprobs
=
2
,
disable_logprobs
=
False
,
temperature
=
0.0
,
ensure_all_accepted
=
ensure_all_accepted
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name"
:
"JackFram/llama-68m"
,
},
{
"model_name"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name"
:
"JackFram/llama-68m"
,
},
{
"model_name"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"max_output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
max_output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
,
seed
=
seed
,
temperature
=
0.0
,
ignore_eos
=
False
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model_name"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use decently long output len for a high quality test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_real_model_bs1
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# A "real" model (not tiny).
"model_name"
:
"meta-llama/Llama-2-7b-chat-hf"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_real_model_large_bs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
16
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model_name"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
# https://github.com/triton-lang/triton/issues/2266 tl.dot
# doesn't support embedding < 16
{
"block_size"
:
16
,
},
{
"block_size"
:
32
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_spec_decode_different_block_size
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality over different block sizes.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
32
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
32
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_skip_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"disable_by_batch_size"
:
2
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"disable_by_batch_size"
:
2
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_disable_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when all sequences disable speculation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
},
"enable_chunked_prefill"
:
False
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
]
+
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
}
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_many_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"acceptance_method"
:
"typical_acceptance_sampler"
,
},
"enable_chunked_prefill"
:
False
}
# Try a range of common k.
for
k
in
[
1
,
2
,
3
]
]
+
[{
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
"acceptance_method"
:
"typical_acceptance_sampler"
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}
for
k
in
[
1
,
2
,
3
]])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
create_new_process_for_each_test
()
def
test_typical_acceptance_sampling
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_ngram_correctness.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the
correctness for the target model outputs.
"""
import
pytest
from
..utils
import
maybe_enable_chunked_prefill
from
.conftest
import
run_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model_name"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
False
,
},
},
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"prefill_chunk_size"
,
[
-
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
prefill_chunk_size
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
common_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model_name"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_logprobs"
:
False
,
},
},
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_logprobs"
:
True
,
},
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_ngram_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
16
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model_name"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
},
"enable_chunked_prefill"
:
False
,
},
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
0
,
seed
=
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
k
,
"prompt_lookup_max"
:
3
,
},
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
]
+
[
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
k
,
"prompt_lookup_max"
:
1
,
},
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram prompt_lookup_max.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_by_batch_size"
:
4
},
},
{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_by_batch_size"
:
4
,
"disable_mqa_scorer"
:
True
,
},
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram prompt_lookup_max.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# The original model is float32, keep it for numerical stability.
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_config"
:
{
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_seed.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"JackFram/llama-160m"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# speculative config
"speculative_config"
:
{
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
},
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"seed"
:
5
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
20
,
])
def
test_seeded_consistency
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
temperature
:
float
,
output_len
:
int
):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
disable_seed
=
False
,
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
disable_seed
=
True
,
)
tests/spec_decode/test_batch_expansion.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
.utils
import
create_seq_group_metadata_from_prompts
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'num_target_seq_ids'
,
[
100
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_create_target_seq_id_iterator
(
num_target_seq_ids
:
int
):
"""Verify all new sequence ids are greater than all input
seq ids.
"""
scorer
=
BatchExpansionTop1Scorer
(
mock_worker
(),
'cuda:0'
,
32_000
)
all_seq_ids
=
[
[
1
,
3
,
5
,
7
],
list
(
range
(
100
))
+
[
0
],
[
100
],
]
for
seq_ids
in
all_seq_ids
:
max_seq_id
=
max
(
seq_ids
)
iterator
=
scorer
.
_create_target_seq_id_iterator
(
seq_ids
)
# pylint: disable=protected-access
for
_
in
range
(
num_target_seq_ids
):
assert
next
(
iterator
)
>
max_seq_id
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_get_token_ids_to_score
(
k
:
int
):
"""Verify correct tokens are selected for scoring.
"""
proposal_token_ids
=
torch
.
tensor
(
list
(
range
(
k
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
,
)
expected_output
:
list
[
list
[
int
]]
=
[
[],
]
for
i
in
range
(
proposal_token_ids
.
shape
[
0
]):
expected_output
.
append
(
proposal_token_ids
[:
i
+
1
].
tolist
())
scorer
=
BatchExpansionTop1Scorer
(
mock_worker
(),
'cuda:0'
,
32_000
)
actual_output
=
scorer
.
_get_token_ids_to_score
(
proposal_token_ids
.
tolist
())
# pylint: disable=protected-access
actual_output
=
[
x
.
tolist
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
for
x
in
actual_output
]
assert
actual_output
==
expected_output
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_create_single_target_seq_group_metadata
(
k
:
int
):
"""Verify correct creation of a batch-expanded seq group metadata.
"""
prompt_tokens
=
[
1
,
2
,
3
]
prev_output_tokens
=
[
4
,
5
,
6
]
token_ids
=
list
(
range
(
k
))
num_tokens_processed
=
len
(
prompt_tokens
)
+
len
(
prev_output_tokens
)
-
1
final_seq_len
=
len
(
prompt_tokens
)
+
len
(
prev_output_tokens
)
+
len
(
token_ids
)
block_size
=
32
input_seq_group_metadata
=
create_seq_group_metadata_from_prompts
(
[
prompt_tokens
],
2048
//
block_size
,
block_size
,
[
final_seq_len
],
[
prev_output_tokens
],
[
num_tokens_processed
])[
0
]
input_seq_id
=
list
(
input_seq_group_metadata
.
seq_data
.
keys
())[
0
]
target_seq_id
=
100
scorer
=
BatchExpansionTop1Scorer
(
mock_worker
(),
'cuda:0'
,
32_000
)
output
=
scorer
.
_create_single_target_seq_group_metadata
(
# pylint: disable=protected-access
input_seq_group_metadata
,
input_seq_id
,
target_seq_id
,
token_ids
,
input_seq_group_metadata
.
sampling_params
,
)
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
assert
output
.
sampling_params
.
repetition_penalty
==
\
input_seq_group_metadata
.
sampling_params
.
repetition_penalty
assert
output
.
sampling_params
.
temperature
==
\
input_seq_group_metadata
.
sampling_params
.
temperature
assert
output
.
sampling_params
.
top_p
==
\
input_seq_group_metadata
.
sampling_params
.
top_p
assert
output
.
sampling_params
.
top_k
==
\
input_seq_group_metadata
.
sampling_params
.
top_k
assert
len
(
output
.
seq_data
)
==
1
assert
output
.
seq_data
[
target_seq_id
].
get_prompt_token_ids
()
==
tuple
(
prompt_tokens
)
assert
output
.
seq_data
[
target_seq_id
].
get_output_token_ids
()
==
tuple
(
prev_output_tokens
+
token_ids
)
assert
len
(
output
.
block_tables
)
==
1
assert
output
.
block_tables
[
target_seq_id
]
==
input_seq_group_metadata
.
block_tables
[
input_seq_id
]
tests/spec_decode/test_dynamic_spec_decode.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
create_batch
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'queue_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size
=
3
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
disable_by_batch_size
=
disable_by_batch_size
)
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
running_queue_size
=
queue_size
)
if
queue_size
>
disable_by_batch_size
:
with
patch
.
object
(
worker
,
'_run_no_spec'
,
side_effect
=
ValueError
(
exception_secret
)),
\
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens
=
None
if
queue_size
<
disable_by_batch_size
else
0
assert
seq_group_metadata_list
[
0
].
num_speculative_tokens
==
expected_num_spec_tokens
draft_worker
.
sampler_output
.
side_effect
=
ValueError
(
exception_secret
)
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
'cpu'
,
# not used
vocab_size
=
100
,
# not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len
=
1024
,
)
if
queue_size
<
disable_by_batch_size
:
# Should raise exception when executing the mocked draft model.
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
else
:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
]
*
batch_size
tests/spec_decode/test_memory_usage.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This docstring details important information on the testing methodology.
This test verifies that memory usage remains constant (or never grows) when
we enable / disable speculation via --speculative-disable-by-batch-size.
There are a lot of things we try to keep track of between batches of requests
and if certain tensors are not freed from memory, can result in CUDA ooms.
This is particularly relevant for production situations where speculation might
be enabled during off hours, but disabled once traffic peaks during the workday.
Since traffic will stay high for a long period of time, verifying we do not
increase our memory usage over time is essential to prevent possible CUDA ooms.
"""
import
torch
import
vllm
from
tests.core.utils
import
create_dummy_prompt
from
vllm.sequence
import
SequenceGroup
ITERATIONS
=
100
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
BATCH_SIZE
=
5
SPEC_DISABLE_BATCH_SIZE
=
2
def
add_seq_group_to_engine
(
engine
:
vllm
.
LLMEngine
,
seq_group
:
SequenceGroup
):
scheduler
=
engine
.
scheduler
[
0
]
scheduler
.
add_seq_group
(
seq_group
)
"""
Since we are using a batch size greater than the disabled batch size,
we can ensure we go through the _no_spec codepath for most of our engine steps.
"""
def
test_memory_usage_no_spec
():
previous_memory_allocated
=
None
llm
=
vllm
.
LLM
(
model
=
MAIN_MODEL
,
speculative_config
=
{
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
3
,
"disable_by_batch_size"
:
SPEC_DISABLE_BATCH_SIZE
,
})
batch_sequences
=
set
()
engine
=
llm
.
llm_engine
for
i
in
range
(
ITERATIONS
):
seq
,
seq_group
=
create_dummy_prompt
(
request_id
=
str
(
i
),
prompt_length
=
10
,
min_tokens
=
10
,
max_tokens
=
10
)
add_seq_group_to_engine
(
engine
,
seq_group
)
batch_sequences
.
add
(
seq
)
engine
.
step
()
for
seq
in
list
(
batch_sequences
):
if
seq
.
is_finished
():
batch_sequences
.
remove
(
seq
)
# If we aren't at our batch size yet, continue
if
len
(
batch_sequences
)
<=
BATCH_SIZE
:
continue
# Otherwise, loop until at least one request is done
while
not
any
(
seq
.
is_finished
()
for
seq
in
batch_sequences
):
engine
.
step
()
# Remove it from the set
for
seq
in
list
(
batch_sequences
):
if
seq
.
is_finished
():
batch_sequences
.
remove
(
seq
)
# At this point, we are always at the case where we have finished
# processing some number of requests from the batch after running
# several _no_spec executions. The memory should not have
# increased between the previous time this was recorded and the
# current time.
if
previous_memory_allocated
is
None
:
previous_memory_allocated
=
torch
.
cuda
.
memory_allocated
()
else
:
assert
previous_memory_allocated
==
torch
.
cuda
.
memory_allocated
()
tests/spec_decode/test_metrics.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
def
test_initial_call_returns_none
():
"""Expect first call to get metrics to return None.
"""
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
)
collector
.
init_gpu_tensors
(
rank
=
0
)
maybe_metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
maybe_metrics
is
None
def
test_second_call_returns_metrics
():
"""Expect second call to not return None.
"""
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
.
side_effect
=
[
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
not
None
@
pytest
.
mark
.
parametrize
(
"rank"
,
[
1
,
2
,
3
,
4
])
def
test_nonzero_rank_noop
(
rank
):
"""Verify nonzero ranks don't collect metrics.
"""
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
0
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
)
collector
.
init_gpu_tensors
(
rank
=
rank
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
None
def
test_noop_until_time
():
"""Verify metrics aren't collected until enough time passes.
"""
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
.
side_effect
=
[
0.0
,
collect_interval_s
-
0.1
,
collect_interval_s
-
0.1
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.1
]
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
None
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
not
None
def
test_timer_is_reset
():
"""Verify that the internal timer inside AsyncMetricsCollector
is reset after collection.
"""
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
0
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
.
side_effect
=
[
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
,
collect_interval_s
+
0.2
,
2
*
collect_interval_s
+
0.1
,
2
*
collect_interval_s
+
0.1
,
]
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
not
None
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
None
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
=
5
)
assert
metrics
is
not
None
@
pytest
.
mark
.
parametrize
(
"has_data"
,
[
True
,
False
])
def
test_initial_metrics_has_correct_values
(
has_data
:
bool
):
"""Test correctness of metrics data.
"""
if
has_data
:
num_accepted_tokens
=
103
num_emitted_tokens
=
104
num_draft_tokens
=
105
else
:
num_accepted_tokens
=
0
num_emitted_tokens
=
0
num_draft_tokens
=
0
k
=
5
max_num_emitted_tokens
=
AsyncMetricsCollector
.
get_max_num_emitted_tokens
(
num_draft_tokens
,
k
)
spec_decode_sampler
=
MagicMock
()
spec_decode_sampler
.
num_accepted_tokens
=
torch
.
tensor
(
num_accepted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_emitted_tokens
=
torch
.
tensor
(
num_emitted_tokens
,
dtype
=
torch
.
long
,
device
=
'cuda'
)
spec_decode_sampler
.
num_draft_tokens
=
num_draft_tokens
collect_interval_s
=
5.0
timer
=
MagicMock
()
timer
.
side_effect
=
[
0.0
,
collect_interval_s
+
0.1
,
collect_interval_s
+
0.2
]
collector
=
AsyncMetricsCollector
(
spec_decode_sampler
=
spec_decode_sampler
,
timer
=
timer
,
collect_interval_s
=
collect_interval_s
)
collector
.
init_gpu_tensors
(
rank
=
0
)
_
=
collector
.
maybe_collect_rejsample_metrics
(
k
)
metrics
=
collector
.
maybe_collect_rejsample_metrics
(
k
)
assert
metrics
.
num_spec_tokens
==
k
assert
metrics
.
accepted_tokens
==
num_accepted_tokens
assert
metrics
.
draft_tokens
==
num_draft_tokens
assert
metrics
.
emitted_tokens
==
num_emitted_tokens
if
has_data
:
assert
(
metrics
.
draft_acceptance_rate
==
num_accepted_tokens
/
num_draft_tokens
)
assert
(
metrics
.
system_efficiency
==
num_emitted_tokens
/
max_num_emitted_tokens
)
else
:
assert
math
.
isnan
(
metrics
.
draft_acceptance_rate
)
assert
math
.
isnan
(
metrics
.
system_efficiency
)
tests/spec_decode/test_multi_step_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.attention.selector
import
(
_Backend
,
global_force_attn_backend_context_manager
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
Logprob
,
get_all_seq_ids
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
create_seq_group_metadata_from_prompts
,
create_worker
,
patch_execute_model_with_seeds
,
zero_kv_cache
)
@
pytest
.
mark
.
parametrize
(
'num_steps'
,
list
(
range
(
1
,
17
)))
def
test_assert_enough_kv_space
(
num_steps
:
int
):
"""Test that the multi step worker checks for sufficient space in the KV
cache. It should throw if it cannot run all the steps.
"""
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
prompts
=
[
list
(
range
(
block_size
*
3
)),
list
(
range
(
block_size
*
2
)),
]
prev_output_tokens
=
[
list
(
range
(
block_size
*
1
)),
list
(
range
(
block_size
*
2
)),
]
final_prompt_lens
=
[
len
(
prompt
+
output
)
+
num_steps
for
prompt
,
output
in
zip
(
prompts
,
prev_output_tokens
)
]
inputs
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
continuations
=
prev_output_tokens
)
assert_enough_kv_space
=
MultiStepWorker
.
_assert_enough_kv_space
# pylint: disable=protected-access
worker
=
MagicMock
()
worker
.
model_runner
.
block_size
=
block_size
for
seq_group_metadata
in
inputs
:
original_block_tables
=
seq_group_metadata
.
block_tables
# No exception.
assert_enough_kv_space
(
worker
,
inputs
,
num_steps
)
seq_group_metadata
.
block_tables
=
{
seq_id
:
[]
for
seq_id
,
physical_blocks
in
original_block_tables
.
items
()
}
# Expect exception.
with
pytest
.
raises
(
ValueError
,
match
=
'times but found insufficient KV space for'
):
assert_enough_kv_space
(
worker
,
inputs
,
num_steps
)
seq_group_metadata
.
block_tables
=
original_block_tables
@
torch
.
inference_mode
()
def
test_same_output_for_single_step
():
"""Verify the multi step worker produces the same output as the normal
worker for num_steps=1.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps
=
1
prompts
=
[
[
1
,
2
,
3
,
4
,
5
],
[
6
,
7
,
8
,
9
,
10
],
]
final_prompt_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
multi_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
multi_step_seq_group
),
sample_len
=
num_steps
,
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
single_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
worker
.
cache_engine
)
set_random_seed
(
seed
)
expected_output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
single_step_seq_group
))[
0
]
actual_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
actual_output
]
actual_logprobs
=
[
output
.
samples
[
0
].
logprobs
for
output
in
actual_output
]
expected_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
expected_output
]
expected_logprobs
=
[
output
.
samples
[
0
].
logprobs
for
output
in
expected_output
]
assert
actual_token_ids
==
expected_token_ids
print
(
f
'
{
actual_logprobs
=
}
'
)
print
(
f
'
{
expected_logprobs
=
}
'
)
assert_logprobs_dict_allclose
(
actual_logprobs
,
expected_logprobs
)
@
torch
.
inference_mode
()
def
test_same_output_for_multi_step
():
"""Verify the multi-step worker produces the same output as the normal
worker when num_steps > 1. This test runs the multi-step worker once, and
then runs the worker num_steps times, and compares the output.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
# Make sure we go over the block boundary.
num_steps
=
block_size
+
1
random
.
seed
(
seed
)
prompts
=
[[
random
.
randint
(
0
,
1000
)
for
_
in
range
(
random
.
randint
(
10
,
20
))
]
for
_
in
range
(
10
)]
final_prompt_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
continuations
=
[[
1
]
for
_
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
,
seq_ids_with_bonus_token_in_last_step
=
set
())
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
list
[
SamplerOutput
]
=
[]
continuations
=
[[
1
]
for
_
in
prompts
]
set_random_seed
(
seed
)
for
_
in
multi_step_output
:
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs
:
list
[
list
[
dict
[
int
,
Logprob
]]]
=
[[]
for
_
in
prompts
]
single_step_output_logprobs
:
list
[
list
[
dict
[
int
,
Logprob
]]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
:
list
[
list
[
int
]]
=
[[]
for
_
in
prompts
]
single_step_output_token_ids
:
list
[
list
[
int
]]
=
[[]
for
_
in
prompts
]
for
i
,
_
in
enumerate
(
prompts
):
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
single_step_output
):
multi_step_output_token_ids
[
i
].
append
(
multi_step
[
i
].
samples
[
0
].
output_token
)
single_step_output_token_ids
[
i
].
append
(
single_step
[
i
].
samples
[
0
].
output_token
)
multi_step_output_logprobs
[
i
].
append
(
multi_step
[
i
].
samples
[
0
].
logprobs
)
single_step_output_logprobs
[
i
].
append
(
single_step
[
i
].
samples
[
0
].
logprobs
)
# Print per-sequence token ids
for
i
,
(
multi_step_tokens
,
single_step_tokens
)
in
enumerate
(
zip
(
multi_step_output_token_ids
,
single_step_output_token_ids
)):
print
(
f
'
{
i
=
}
{
multi_step_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
single_step_tokens
=
}
'
)
print
(
f
'
{
i
=
}
equal
{
multi_step_tokens
==
single_step_tokens
}
'
)
# Assert token ids are equal.
for
multi_step_tokens
,
single_step_tokens
in
zip
(
multi_step_output_token_ids
,
single_step_output_token_ids
):
assert
multi_step_tokens
==
single_step_tokens
# Assert logprobs are equal.
for
multi_step_logprobs
,
single_step_logprobs
in
zip
(
multi_step_output_logprobs
,
single_step_output_logprobs
):
assert_logprobs_dict_allclose
(
multi_step_logprobs
,
single_step_logprobs
)
@
torch
.
inference_mode
()
def
test_multi_step_with_batch_expansion_correct_output
():
"""
In this test we verify that the MultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
batch_size
=
128
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
multi_step_worker
.
set_include_gpu_probs_tensor
()
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
random
.
seed
(
seed
)
prompts
=
[[
0
]
for
_
in
range
(
batch_size
)]
num_steps
=
2
final_prompt_lens
=
[(
num_steps
+
1
)
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
# Create the test continuations
continuations
=
[[
random
.
randint
(
0
,
1000
)]
for
_
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
list
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations
=
[]
for
continuation
in
continuations
:
multi_step_continuations
.
append
(
continuation
[:
2
])
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
all_seq_ids
=
{
i
for
i
in
range
(
batch_size
)}
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
1
,
seq_ids_with_bonus_token_in_last_step
=
all_seq_ids
)
for
index
,
output
in
enumerate
(
multi_step_output
[
-
1
].
outputs
):
assert
(
continuations
[
index
][
-
1
]
==
output
.
samples
[
0
].
output_token
)
@
torch
.
inference_mode
()
def
test_multi_step_with_batch_expansion_incorrect_output
():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
batch_size
=
128
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
multi_step_worker
.
set_include_gpu_probs_tensor
()
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
random
.
seed
(
seed
)
prompts
=
[[
0
]
for
_
in
range
(
batch_size
)]
num_steps
=
2
final_prompt_lens
=
[(
num_steps
+
1
)
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
# Create the test continuations
continuations
=
[[
random
.
randint
(
0
,
1000
)]
for
_
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
list
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
.
extend
(
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations
=
[]
for
continuation
in
continuations
:
multi_step_continuations
.
append
(
continuation
[:
2
])
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step. In this run INCORRECTLY specify that only the odd number
# sequences have bonus tokens. Verify that with this setting the third token
# prediction is accurate only for the odd numbered sequences. Also verify
# that the prediction might be wrong for some of the even numbered
# sequences.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
odd_seq_ids
=
{
i
for
i
in
range
(
batch_size
)
if
i
%
2
!=
0
}
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
1
,
seq_ids_with_bonus_token_in_last_step
=
odd_seq_ids
)
num_mismatch
=
0
for
index
,
output
in
enumerate
(
multi_step_output
[
-
1
].
outputs
):
if
(
index
%
2
)
!=
0
:
assert
(
continuations
[
index
][
-
1
]
==
output
.
samples
[
0
].
output_token
)
elif
(
continuations
[
index
][
-
1
]
!=
output
.
samples
[
0
].
output_token
):
num_mismatch
+=
1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert
(
num_mismatch
>
0
)
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
'num_steps'
,
[
1
,
2
,
3
,
4
])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner and that we can test
# both code paths.
@
pytest
.
mark
.
parametrize
(
'attn_backend'
,
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
])
def
test_multi_step_correct_kvcache
(
num_steps
,
attn_backend
):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed
=
100
model_name
=
"JackFram/llama-68m"
block_size
=
16
num_gpu_blocks
=
2048
//
block_size
batch_size
=
1
with
global_force_attn_backend_context_manager
(
attn_backend
):
dtype
=
'float16'
if
attn_backend
==
_Backend
.
FLASH_ATTN
else
'float32'
multi_step_worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
dtype
=
dtype
)
multi_step_worker
.
set_include_gpu_probs_tensor
()
worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
dtype
=
dtype
)
prompts
=
[[
0
]
for
_
in
range
(
batch_size
)]
# Already generate two tokens for the sequence
# so that we can simulate the bonus token case
multi_step_continuations
=
[[
random
.
randint
(
0
,
1000
),
random
.
randint
(
0
,
1000
)
]
for
_
in
prompts
]
final_prompt_lens
=
[
len
(
prompt
)
+
2
+
num_steps
for
prompt
in
prompts
]
seq_ids_with_bonus_token_in_last_step
=
set
(
range
(
batch_size
))
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
multi_step_worker
.
sampler_output
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
,
seq_ids_with_bonus_token_in_last_step
=
seq_ids_with_bonus_token_in_last_step
)
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
# Generate the kv cache for the bonus token first
single_step_continuations
=
[
c
[:
1
]
for
c
in
multi_step_continuations
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
single_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
))
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
continuations
=
multi_step_continuations
,
final_prompt_lens
=
final_prompt_lens
)
single_step_output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
))
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
multi_step_continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Verify that the KV cache of the single-step and
# multi-step workers are the same.
single_step_gpu_cache
=
worker
.
cache_engine
[
0
].
gpu_cache
multi_step_gpu_cache
=
multi_step_worker
.
cache_engine
[
0
].
gpu_cache
num_layers
=
len
(
single_step_gpu_cache
)
allclose
=
lambda
a
,
b
:
torch
.
allclose
(
a
.
cuda
(),
b
.
cuda
(),
rtol
=
1e-2
,
atol
=
1e-2
)
for
i
in
range
(
num_layers
):
assert
allclose
(
single_step_gpu_cache
[
i
][
0
],
multi_step_gpu_cache
[
i
][
0
])
assert
allclose
(
single_step_gpu_cache
[
i
][
1
],
multi_step_gpu_cache
[
i
][
1
])
@
torch
.
inference_mode
()
def
test_draft_proposals_full_speculation_len
():
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k
=
10
batch_size
=
32
vocab_size
=
32_000
device
=
'cuda:0'
draft_worker
=
MagicMock
()
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
2048
,
)
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
),
device
=
device
,
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
],
True
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
batch_size
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
k
for
_
in
range
(
batch_size
)]
@
torch
.
inference_mode
()
def
test_draft_proposals_no_speculations
():
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k
=
10
batch_size
=
32
vocab_size
=
32_000
device
=
'cuda:0'
prompt_len
=
10
draft_worker
=
MagicMock
()
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
prompt_len
+
k
-
1
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
batch_size
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
for
_
in
range
(
batch_size
)]
@
torch
.
inference_mode
()
def
test_draft_proposals_mixed_k
():
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k
=
10
batch_size
=
32
vocab_size
=
32_000
device
=
'cuda:0'
small_prompt_len
=
5
long_prompt_len
=
10
prev_output_token_len
=
20
expected_num_proposal_seqs
=
6
expected_num_no_proposal_seqs
=
batch_size
-
expected_num_proposal_seqs
prompt_len
=
[
small_prompt_len
for
_
in
range
(
expected_num_proposal_seqs
-
1
)
]
+
[
long_prompt_len
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
small_prompt_len
]
draft_worker
=
MagicMock
()
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
long_prompt_len
+
prev_output_token_len
+
k
-
1
,
)
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
expected_num_proposal_seqs
,
),
device
=
device
,
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
],
True
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prompt_len
=
prompt_len
,
prev_output_token_len
=
prev_output_token_len
,
)
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
seq_ids_with_bonus_token_in_last_step
=
set
())
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
batch_size
,
k
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
batch_size
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
k
for
_
in
range
(
expected_num_proposal_seqs
-
1
)
]
+
[
0
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
k
]
@
torch
.
inference_mode
()
def
test_use_draft_model_runner_advance_step
():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
k
=
5
batch_size
=
32
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret
=
"artificial stop"
worker
.
model_runner
.
_gpu_advance_step
=
MagicMock
()
worker
.
model_runner
.
_gpu_advance_step
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
# Fallback (should not call) when num_steps=1.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
num_steps
=
1
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# Expect exception if _gpu_advance_step is called.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
num_steps
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
assert
len
(
call_args_list
)
==
1
@
torch
.
inference_mode
()
def
test_expand_execute_model_request_sync_with_expand_hidden_states
():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k
=
5
batch_size
=
16
seq_with_bonus_token_in_last_step
=
[
1
,
3
,
8
,
10
,
13
,
15
]
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_request
=
ExecuteModelRequest
(
seq_group_metadata_list
,
previous_hidden_states
=
HiddenStates
(
torch
.
arange
(
batch_size
),
seq_group_metadata_list
,
torch
.
arange
(
batch_size
,
2
*
batch_size
)))
expanded_execute_model_request
,
orig_seq_group_ids
=
MultiStepWorker
.
\
_expand_execute_model_request
(
execute_model_request
,
seq_with_bonus_token_in_last_step
)
all_seq_ids
=
torch
.
tensor
(
get_all_seq_ids
(
expanded_execute_model_request
.
seq_group_metadata_list
))
ref_expanded_hidden_states
=
all_seq_ids
+
batch_size
ref_expanded_hidden_states
[
orig_seq_group_ids
]
-=
batch_size
assert
(
ref_expanded_hidden_states
==
expanded_execute_model_request
.
previous_hidden_states
.
hidden_states
).
all
().
item
()
tests/spec_decode/test_ngram_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
create_seq_group_metadata_from_prompts
,
create_worker
def
test_ngram_algo_correctness_for_single_no_match
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
1
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
]
def
test_ngram_algo_correctness_for_batches_not_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
# shall find no candidate as exceed max_proposal_len
[
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
5
])
# the first sequence has no match so proposal_len should be overwritten to 0
assert
proposals
.
proposal_lens
.
tolist
(
)
==
[
0
]
+
[
proposal_len
for
_
in
range
(
3
)]
+
[
0
]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
-
1
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
3
][
i
]
==
prompts
[
3
][
i
+
5
]
assert
proposals
.
proposal_token_ids
[
4
][
i
]
==
-
1
def
test_ngram_algo_correctness_for_batches_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batches
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window [0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
1
,
3
)
prompts
=
[
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for
sg
in
seq_group_metadata_list
:
sg
.
is_prompt
=
False
proposals
=
proposer
.
get_spec_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
seq_ids_with_bonus_token_in_last_step
=
None
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
3
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
proposal_len
for
_
in
range
(
3
)]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
prompts
[
0
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
5
]
tests/spec_decode/test_scorer.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
pytest
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
,
SpeculativeScores
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.worker.worker
import
Worker
from
.utils
import
create_batch
,
create_worker
def
create_proposal
(
propose_lens
:
list
[
int
],
vocab_size
:
int
,
device
:
str
)
->
SpeculativeProposals
:
batch_size
=
len
(
propose_lens
)
max_propose_len
=
max
(
propose_lens
)
proposal_probs
=
torch
.
rand
((
batch_size
,
max_propose_len
,
vocab_size
),
device
=
device
)
proposal_token_ids
=
torch
.
full
((
batch_size
,
max_propose_len
),
fill_value
=-
1
,
device
=
device
)
for
i
in
range
(
batch_size
):
proposal_token_ids
[
i
][:
propose_lens
[
i
]]
=
torch
.
argmax
(
proposal_probs
[
i
][:
propose_lens
[
i
]],
dim
=-
1
)
propose_lens
=
torch
.
tensor
(
propose_lens
,
device
=
device
)
return
SpeculativeProposals
(
proposal_token_ids
,
proposal_probs
,
propose_lens
)
def
assert_score_equal
(
score1
:
SpeculativeScores
,
score2
:
SpeculativeScores
)
->
None
:
assert
torch
.
allclose
(
score1
.
probs
,
score2
.
probs
)
assert
torch
.
allclose
(
score1
.
logprobs
,
score2
.
logprobs
)
assert
torch
.
equal
(
score1
.
token_ids
,
score2
.
token_ids
),
f
"
{
score1
.
token_ids
}
,
{
score2
.
token_ids
}
"
@
pytest
.
mark
.
parametrize
(
'model_name'
,
[
'facebook/opt-125m'
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
'max_propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'mixed_propose_len'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
@
pytest
.
mark
.
parametrize
(
'prefill_chunking'
,
[
False
,
True
])
def
test_scorer
(
model_name
:
str
,
batch_size
:
int
,
max_propose_len
:
int
,
mixed_propose_len
:
bool
,
device
:
str
,
prefill_chunking
:
bool
)
->
None
:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length, as well as mixed prefill-decode batches.
"""
seed
=
0
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
scorer_worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
)
scorer_worker
.
model_runner
.
disable_logprobs
=
True
# accessed by mqa_scorer
scorer_worker
.
model_runner
.
sampler
.
include_gpu_probs_tensor
=
True
scorer_worker
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
=
True
vocab_size
=
scorer_worker
.
vocab_size
if
not
mixed_propose_len
:
propose_lens
=
[
max_propose_len
]
*
batch_size
else
:
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt
=
random
.
randint
(
1
,
batch_size
)
propose_lens
=
[
max_propose_len
]
*
non_zero_cnt
+
[
0
]
*
(
batch_size
-
non_zero_cnt
)
random
.
shuffle
(
propose_lens
)
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
max_propose_len
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
if
mixed_propose_len
and
prefill_chunking
and
(
n_prefills
:
=
batch_size
-
non_zero_cnt
):
prefill
,
_
,
_
=
create_batch
(
n_prefills
,
None
,
prefill_chunk_size
=
4
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
seq_ids
=
list
(
range
(
batch_size
,
batch_size
+
n_prefills
)))
# re-order to guarantee prefill|decode order
target_group_metadatalist
=
[
seq_group_metadatalist
[
i
]
for
i
,
p
in
enumerate
(
propose_lens
)
if
p
>
0
]
seq_group_metadatalist
=
prefill
+
target_group_metadatalist
propose_lens
=
[
0
]
*
n_prefills
+
[
p
for
p
in
propose_lens
if
p
>
0
]
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
num_lookahead_slots
=
max_propose_len
)
batch_expansion_scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
,
device
,
vocab_size
)
batch_expansion_score
=
batch_expansion_scorer
.
score_proposals
(
requests
,
proposals
)
mqa_scorer
=
MQAScorer
(
scorer_worker
,
device
,
vocab_size
)
mqa_score
=
mqa_scorer
.
score_proposals
(
requests
,
proposals
)
assert_score_equal
(
batch_expansion_score
,
mqa_score
)
tests/spec_decode/test_spec_decode_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
collections
import
defaultdict
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceOutput
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
from
vllm.worker.worker
import
Worker
from
.test_utils
import
mock_spec_decode_sampler
from
.utils
import
(
create_batch
,
create_sampler_output_list
,
create_worker
,
mock_worker
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_calls_draft_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
assert
len
(
call_args_list
)
==
1
for
args
,
_
in
call_args_list
:
actual_execute_model_data
=
args
[
0
]
assert
actual_execute_model_data
==
execute_model_req
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_batch_expansion_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the target model with correct
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
disable_mqa_scorer
=
True
)
worker
.
init_device
()
vocab_size
=
32_000
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
proposal_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
seq_group_metadata_list
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
)
exception_secret
=
'artificial stop'
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
seen_contexts
:
list
[
list
[
int
]]
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
for
_
,
kwargs
in
call_args_list
:
seq_group_metadata_list
=
kwargs
[
"execute_model_req"
].
seq_group_metadata_list
assert
len
(
seq_group_metadata_list
)
==
(
k
+
1
)
*
batch_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
expected_seen_contexts
:
list
[
list
[
int
]]
=
[]
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
for
i
in
range
(
len
(
draft_tokens
)
+
1
):
expected_seen_contexts
.
append
(
prompt
+
prev_generated
+
draft_tokens
[:
i
])
seen_contexts
.
sort
()
expected_seen_contexts
.
sort
()
assert
expected_seen_contexts
==
seen_contexts
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_calls_spec_decode_sampler
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
proposal_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
)
target_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
1
,
batch_size
*
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
target_token_probs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
exception_secret
=
'artificial stop'
spec_decode_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
len
(
spec_decode_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
spec_decode_sampler
.
call_args_list
[
0
]
actual
=
SimpleNamespace
(
**
kwargs
)
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
target_token_ids
.
reshape
(
batch_size
,
k
+
1
)[:,
-
1
:])
assert
torch
.
equal
(
actual
.
target_with_bonus_probs
,
target_token_probs
.
reshape
(
batch_size
,
k
+
1
,
-
1
))
assert
torch
.
equal
(
actual
.
draft_token_ids
,
proposal_token_ids
)
assert
torch
.
equal
(
actual
.
draft_probs
,
proposal_probs
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_correctly_formats_output
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
proposal_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
)
target_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
1
,
batch_size
*
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
target_token_probs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
spec_decode_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
spec_decode_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
token_ids
=
spec_decode_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
seq_group_metadata_list
]
actual_output_by_seq
:
dict
[
int
,
list
[
SequenceOutput
]]
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
:
dict
[
int
,
list
[
SequenceOutput
]]
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
for
step
in
output
:
for
seq_group
in
step
:
for
sample
in
seq_group
.
samples
:
seq_id
=
sample
.
parent_seq_id
actual_output_by_seq
[
seq_id
].
append
(
sample
)
for
step
in
expected_output
:
for
seq_group
in
step
:
for
sample
in
seq_group
.
samples
:
seq_id
=
sample
.
parent_seq_id
expected_output_by_seq
[
seq_id
].
append
(
sample
)
all_seen_seq_ids
=
set
(
list
(
actual_output_by_seq
.
keys
())
+
list
(
expected_output_by_seq
.
keys
()))
for
seq_id
in
all_seen_seq_ids
:
actual_by_step
=
actual_output_by_seq
[
seq_id
]
expected_by_step
=
expected_output_by_seq
[
seq_id
]
for
i
in
range
(
k
+
1
):
if
i
>=
len
(
actual_by_step
):
assert
expected_by_step
[
i
].
output_token
==
-
1
continue
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
i
].
output_token
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'returns_metrics'
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_collects_metrics
(
k
:
int
,
batch_size
:
int
,
returns_metrics
:
bool
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
vocab_size
=
vocab_size
,
use_spec
=
False
)
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
proposal_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
)
target_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
1
,
batch_size
*
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
target_token_probs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
spec_decode_sampler_output
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
+
1
),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
for
i
in
range
(
batch_size
):
minimum_accepted_tokens
=
1
spec_decode_sampler_output
[
i
][
-
random
.
randint
(
minimum_accepted_tokens
,
k
+
1
):]
=
-
1
spec_decode_sampler
.
return_value
=
spec_decode_sampler_output
mock_rejsample_metrics
=
MagicMock
(
spec
=
SpecDecodeWorkerMetrics
)
if
returns_metrics
else
None
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
mock_rejsample_metrics
)
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
call_args_list
=
(
metrics_collector
.
maybe_collect_rejsample_metrics
.
call_args_list
)
assert
len
(
call_args_list
)
==
1
args
,
kwargs
=
call_args_list
[
0
]
assert
args
[
0
]
==
k
or
kwargs
.
get
(
'k'
,
-
1
)
==
k
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_k_equals_zero
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
sampler_output
.
hidden_states
=
None
target_worker
.
execute_model
.
return_value
=
[
sampler_output
]
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
sampled_token_probs
is
None
,
(
"expect gpu tensor references to be None"
)
assert
out
[
0
].
sampled_token_ids
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
0
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_empty_input_batch
(
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
sampler_output
=
MagicMock
(
spec
=
SamplerOutput
)
sampler_output
.
hidden_states
=
None
target_worker
.
execute_model
.
return_value
=
[
sampler_output
]
draft_worker
.
device
=
'cuda'
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
sampled_token_probs
is
None
,
(
"expect gpu tensor references to be None"
)
assert
out
[
0
].
sampled_token_ids
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_init_device
(
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
worker
.
init_device
()
draft_worker
.
init_device
.
assert_called_once
()
target_worker
.
init_device
.
assert_called_once
()
metrics_collector
.
init_tensors
.
assert_called_once
()
spec_decode_sampler
.
init_tensors
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
def
test_initialize_cache
(
acceptance_sampler_method
):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
=
metrics_collector
)
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
worker
.
initialize_cache
(
**
kwargs
)
draft_worker
.
initialize_cache
.
assert_called_once_with
(
**
kwargs
)
target_worker
.
initialize_cache
.
assert_called_once_with
(
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
'available_gpu_blocks'
,
[
1
,
1024
])
@
pytest
.
mark
.
parametrize
(
'available_cpu_blocks'
,
[
500
])
@
pytest
.
mark
.
parametrize
(
'target_cache_block_size_bytes'
,
[
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
'draft_kv_size_bytes'
,
[
0
,
2
*
2
*
768
,
2
*
2
*
4096
])
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_determine_num_available_blocks
(
available_gpu_blocks
:
int
,
available_cpu_blocks
:
int
,
target_cache_block_size_bytes
:
int
,
draft_kv_size_bytes
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
determine_num_available_blocks
.
return_value
=
(
available_gpu_blocks
,
available_cpu_blocks
)
target_worker
.
get_cache_block_size_bytes
.
return_value
=
(
target_cache_block_size_bytes
)
draft_worker
.
get_cache_block_size_bytes
.
return_value
=
draft_kv_size_bytes
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
num_gpu_blocks
,
num_cpu_blocks
=
worker
.
determine_num_available_blocks
()
target_worker
.
determine_num_available_blocks
.
assert_called_once
()
assert
num_cpu_blocks
==
available_cpu_blocks
assert
num_gpu_blocks
==
split_num_cache_blocks_evenly
(
target_cache_block_size_bytes
,
draft_kv_size_bytes
,
available_gpu_blocks
)
@
pytest
.
mark
.
parametrize
(
'available_gpu_blocks'
,
list
(
range
(
20
))
+
[
1024
,
1024
**
2
])
@
pytest
.
mark
.
parametrize
(
'target_cache_block_size_bytes'
,
[
2
*
2
*
4096
,
2
*
2
*
8192
])
@
pytest
.
mark
.
parametrize
(
'draft_kv_size_bytes'
,
[
0
,
2
*
2
*
768
,
2
*
2
*
4096
])
@
pytest
.
mark
.
skip_global_cleanup
def
test_split_num_cache_blocks_evenly
(
available_gpu_blocks
:
int
,
target_cache_block_size_bytes
:
int
,
draft_kv_size_bytes
:
int
):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks
=
split_num_cache_blocks_evenly
(
target_cache_block_size_bytes
,
draft_kv_size_bytes
,
available_gpu_blocks
)
assert
(
num_blocks
*
target_cache_block_size_bytes
)
+
(
num_blocks
*
draft_kv_size_bytes
)
<=
(
available_gpu_blocks
*
target_cache_block_size_bytes
)
@
torch
.
inference_mode
()
def
test_populate_seq_ids_with_bonus_tokens
():
"""
Verify that a call to _create_output_sampler_list correctly updates
seq_with_bonus_token_in_last_step.
seq_with_bonus_token_in_last_step is an internal data structure in
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
tokens by the target model in their last forward pass. This state is
maintained only for models relying on the KV cache, such as those using
the MultiStepWorker.
"""
batch_size
=
10
k
=
5
vocab_size
=
10000
num_sequences_with_bonus_tokens
=
5
target_worker
=
mock_worker
(
vocab_size
=
vocab_size
,
use_spec
=
False
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
target_worker
.
execute_model
.
return_value
=
[
MagicMock
(
spec
=
SamplerOutput
)]
target_worker
.
device
=
'cuda'
set_random_seed
(
1
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
draft_worker
.
device
=
'cuda'
# The sequence_ids attached to each sequence in the batch.
# The sequence at index i has seq_id assigned_seq_ids[i]
assigned_seq_ids
=
list
(
range
(
batch_size
))
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
seq_ids
=
assigned_seq_ids
,
prev_output_token_len
=
10
)
target_token_logprobs
=
torch
.
rand
(
batch_size
,
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
accepted_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
expected_request_id_seq_ids_mapping
:
dict
[
str
,
set
[
int
]]
=
defaultdict
(
set
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_id
in
seq_group_metadata
.
seq_data
:
expected_request_id_seq_ids_mapping
[
seq_group_metadata
.
request_id
].
add
(
seq_id
)
# Generate a random sample of sequence indexes with bonus tokens
seq_indexes_with_bonus_tokens
=
random
.
sample
(
range
(
batch_size
),
num_sequences_with_bonus_tokens
)
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
'cuda'
)
mask
[
seq_indexes_with_bonus_tokens
]
=
False
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids
[
mask
,
-
1
:]
=
-
1
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
"rejection_sampler"
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
# the range [0, batch_size + num_extra_sequence_ids).
num_extra_sequence_ids
=
10
worker
.
_seq_with_bonus_token_in_last_step
=
set
(
range
(
batch_size
+
num_extra_sequence_ids
))
worker
.
_create_output_sampler_list
(
seq_group_metadata_list
=
seq_group_metadata_list
,
accepted_token_ids
=
accepted_token_ids
,
target_logprobs
=
target_token_logprobs
,
prompt_logprobs
=
None
,
k
=
k
,
stage_times
=
(
0
,
0
,
0
))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# batch are retained.
# 2. Of the sequence IDs present in the current batch, only those with a
# bonus token are retained in _seq_with_bonus_token_in_last_step.
# Sequence IDs that are present in the current batch but do not have
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
expected_seq_ids_with_bonus_tokens
=
\
set
([
assigned_seq_ids
[
i
]
for
i
in
seq_indexes_with_bonus_tokens
])
additional_sequence_ids
=
\
set
(
range
(
batch_size
,
batch_size
+
num_extra_sequence_ids
))
assert
worker
.
_seq_with_bonus_token_in_last_step
==
\
expected_seq_ids_with_bonus_tokens
.
union
(
additional_sequence_ids
)
assert
worker
.
_request_id_seq_id_mapping
==
\
expected_request_id_seq_ids_mapping
@
torch
.
inference_mode
()
def
test_handle_finished_requests
():
"""
Test to verify that finished request IDs are appropriately processed to
update the internal state of the SpecDecodeWorker.
This test initializes the SpecDecodeWorker with mock data, marks certain
requests as finished, and ensures that the corresponding sequence IDs are
correctly removed from the internal mappings.
"""
batch_size
=
32
k
=
3
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
"rejection_sampler"
),
metrics_collector
)
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
# request ids and corresponding sequence ids.
worker
.
_request_id_seq_id_mapping
=
\
{
'request-1'
:
{
1
,
2
,
3
},
'request-2'
:
{
4
,
5
,
6
,
7
},
'request-3'
:
{
8
,
9
},
'request-4'
:
{
10
,
11
}}
# Initialize seq_with_bonus_token_in_last_step with a few fake
# sequence ids.
worker
.
_seq_with_bonus_token_in_last_step
=
{
1
,
4
,
5
,
8
,
9
,
10
}
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
# Mark requests with ids request-1 and request-3 as finished.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
finished_requests_ids
=
[
'request-1'
,
'request-3'
])
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# Verify that request-1 and request-3 are removed from
# request_id_seq_id_mapping
assert
worker
.
_request_id_seq_id_mapping
==
\
{
'request-2'
:
{
4
,
5
,
6
,
7
},
'request-4'
:
{
10
,
11
}}
# Verify that all sequence ids corresponding to 'request-1'
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert
worker
.
_seq_with_bonus_token_in_last_step
==
\
{
4
,
5
,
10
}
@
pytest
.
mark
.
parametrize
(
'k'
,
[
3
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_composition"
,
[
"prefill_only"
,
"decode_only"
,
"mixed"
])
@
torch
.
inference_mode
()
def
test_chunked_prefill_flow
(
k
:
int
,
batch_size
:
int
,
batch_composition
:
str
):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size
=
32_000
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
)
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
mock_spec_decode_sampler
(
"rejection_sampler"
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
exception_secret
=
'artificial stop'
worker
.
scorer
=
mock_worker
(
BatchExpansionTop1Scorer
)
worker
.
scorer
.
score_proposals
.
side_effect
=
ValueError
(
exception_secret
)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes
,
_
,
_
=
create_batch
(
batch_size
,
k
)
# Pre-chunking here, get 'batch_size' chunks.
prefill
,
_
,
_
=
create_batch
(
batch_size
,
k
,
prefill_chunk_size
=
4
,
seq_ids
=
list
(
range
(
batch_size
,
batch_size
*
2
)))
if
batch_composition
==
"prefill_only"
:
n_prefills
=
batch_size
elif
batch_composition
==
"decode_only"
:
n_prefills
=
0
else
:
n_prefills
=
random
.
randint
(
1
,
batch_size
-
1
)
n_decodes
=
batch_size
-
n_prefills
prefill
=
random
.
sample
(
prefill
,
n_prefills
)
decodes
=
random
.
sample
(
decodes
,
n_decodes
)
target_group_metadata_list
=
prefill
+
decodes
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
target_group_metadata_list
,
# For prefill only batches we expect num_lookahead_slots = 0.
num_lookahead_slots
=
k
if
n_decodes
>
0
else
0
)
target_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
1
,
batch_size
*
(
k
+
1
)),
dtype
=
torch
.
int64
,
device
=
'cuda'
)
target_token_probs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
if
not
len
(
decodes
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# no spec run (prefill only)
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
else
:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# but first draft still counted
assert
draft_worker
.
get_spec_proposals
.
call_count
==
1
def
test_correctly_load_weight_for_eagle
():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
"""
seed
=
100
block_size
=
32
num_gpu_blocks
=
8096
//
block_size
target_worker
=
create_worker
(
Worker
,
"JackFram/llama-68m"
,
block_size
,
num_gpu_blocks
,
seed
,
)
draft_worker
=
create_worker
(
MultiStepWorker
,
"abhigoyal/vllm-eagle-llama-68m-random"
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
"rejection_sampler"
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
)
worker
.
proposer_worker
.
maybe_load_lm_head_weight
(
target_worker
.
model_runner
.
model
.
lm_head
.
weight
.
data
)
assert
torch
.
allclose
(
worker
.
proposer_worker
.
worker
.
model_runner
.
model
.
lm_head
.
weight
.
data
,
worker
.
scorer_worker
.
model_runner
.
model
.
lm_head
.
weight
.
data
)
tests/spec_decode/test_utils.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
_get_ranks
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
SequenceGroupMetadata
,
get_all_seq_ids
from
vllm.spec_decode.util
import
(
get_sampled_token_logprobs
,
split_batch_by_proposal_len
)
def
test_get_all_seq_ids
():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids
=
list
(
range
(
10
))
+
list
(
range
(
100
,
110
))
seq_group_metadata_list
=
[
SequenceGroupMetadata
(
request_id
=
str
(
seq_id
),
is_prompt
=
True
,
seq_data
=
{
seq_id
:
MagicMock
(),
},
sampling_params
=
MagicMock
(),
block_tables
=
{
seq_id
:
MagicMock
(),
},
lora_request
=
None
,
)
for
seq_id
in
expected_seq_ids
]
actual_seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
assert
actual_seq_ids
==
expected_seq_ids
@
pytest
.
fixture
def
fake_sequence_group_metadata
():
seq_ids
=
list
(
range
(
3
))
return
[
SequenceGroupMetadata
(
request_id
=
str
(
i
),
is_prompt
=
True
,
seq_data
=
{
i
:
MagicMock
(),
},
sampling_params
=
MagicMock
(),
block_tables
=
{
i
:
MagicMock
(),
},
lora_request
=
None
,
)
for
i
in
seq_ids
]
def
test_filter_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
0
]
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
expected_groups
=
[
fake_sequence_group_metadata
[
0
],
fake_sequence_group_metadata
[
2
]
]
expected_indices
=
[
0
,
2
]
assert
filtered_groups
==
expected_groups
assert
indices
==
expected_indices
def
test_filter_non_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
2
]
(
filtered_groups
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
expected_groups
=
[
fake_sequence_group_metadata
[
1
],
fake_sequence_group_metadata
[
2
]
]
expected_indices
=
[
1
,
2
]
assert
filtered_groups
==
expected_groups
assert
indices
==
expected_indices
def
test_empty_inputs
():
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
([],
[])
assert
filtered_groups
==
[]
assert
indices
==
[]
def
test_all_zero_with_non_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
0
,
0
]
(
filtered_groups
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
assert
filtered_groups
==
[]
assert
indices
==
[]
def
test_all_non_zero_with_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
1
,
1
,
1
]
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
assert
filtered_groups
==
[]
assert
indices
==
[]
def
mock_spec_decode_sampler
(
acceptance_sampler_method
):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if
acceptance_sampler_method
==
"rejection_sampler"
:
sampler
=
MagicMock
(
spec
=
RejectionSampler
)
sampler
.
token_id_dtype
=
torch
.
int64
return
sampler
elif
acceptance_sampler_method
==
"typical_acceptance_sampler"
:
sampler
=
MagicMock
(
spec
=
TypicalAcceptanceSampler
)
sampler
.
token_id_dtype
=
torch
.
int64
return
sampler
else
:
raise
ValueError
(
f
"Invalid sampler name
{
acceptance_sampler_method
}
"
)
def
test_get_sampled_token_logprobs
():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor
=
torch
.
tensor
(
[[[
-
.
1
,
-
.
1
]]
*
2
])
# shape (num_steps, batch_size, vocab_size)
sampled_token_tensor
=
torch
.
tensor
([[
1
,
0
]])
# shape (num_steps, batch_size)
ranks_spec_dec
,
_
=
get_sampled_token_logprobs
(
logprob_tensor
,
sampled_token_tensor
)
ranks_regular
=
_get_ranks
(
logprob_tensor
.
reshape
((
2
,
-
1
)),
sampled_token_tensor
.
reshape
(
-
1
))
assert
torch
.
equal
(
ranks_spec_dec
.
reshape
(
-
1
),
ranks_regular
)
tests/spec_decode/utils.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
as
GenericSequence
from
itertools
import
count
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
unittest.mock
import
MagicMock
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker
import
Worker
T
=
TypeVar
(
"T"
,
bound
=
Worker
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
def
mock_worker
(
cls
=
None
,
vocab_size
:
int
=
30_000
,
max_model_len
:
int
=
2048
,
rank
:
int
=
0
,
use_spec
:
bool
=
True
)
->
MagicMock
:
if
cls
is
None
:
cls
=
Worker
spec
=
cls
if
use_spec
else
None
worker
=
MagicMock
(
spec
=
spec
)
worker
.
vocab_size
=
vocab_size
worker
.
max_model_len
=
max_model_len
worker
.
rank
=
rank
worker
.
device
=
'cuda:0'
return
worker
def
patch_execute_model_with_seeds
(
worker
:
Worker
,
rand_seeds
:
list
[
int
]):
seed_iter
=
iter
(
rand_seeds
)
original_execute_model
=
worker
.
execute_model
def
new_execute_model
(
*
args
,
**
kwargs
):
result
=
original_execute_model
(
*
args
,
**
kwargs
)
set_random_seed
(
next
(
seed_iter
))
return
result
return
new_execute_model
def
zero_kv_cache
(
cache_engine
:
list
[
CacheEngine
]):
assert
cache_engine
[
0
].
gpu_cache
for
key_blocks
,
value_blocks
in
cache_engine
[
0
].
gpu_cache
:
key_blocks
.
zero_
()
value_blocks
.
zero_
()
def
create_worker
(
cls
:
Callable
[...,
T
],
model_name
:
str
,
block_size
:
int
,
num_gpu_blocks
:
int
,
seed
:
int
,
is_driver_worker
:
bool
=
True
,
enforce_eager
:
bool
=
True
,
model_runner_cls
:
Optional
[
ModelRunner
]
=
None
,
dtype
:
Optional
[
str
]
=
"auto"
)
->
T
:
engine_args
=
EngineArgs
(
model
=
model_name
,
seed
=
seed
,
block_size
=
block_size
,
enforce_eager
=
enforce_eager
,
dtype
=
dtype
,
)
engine_config
=
engine_args
.
create_engine_config
()
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
worker
=
cls
(
vllm_config
=
engine_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
,
model_runner_cls
=
model_runner_cls
,
)
worker
.
init_device
()
worker
.
load_model
()
engine_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
engine_config
.
cache_config
.
num_cpu_blocks
=
0
worker
.
initialize_cache
(
num_gpu_blocks
=
engine_config
.
cache_config
.
num_gpu_blocks
,
num_cpu_blocks
=
engine_config
.
cache_config
.
num_cpu_blocks
)
return
worker
def
create_seq_group_metadata_from_prompts
(
prompts
:
list
[
list
[
int
]],
num_gpu_blocks
:
int
,
block_size
:
int
,
final_prompt_lens
:
list
[
int
],
continuations
:
Optional
[
list
[
list
[
int
]]]
=
None
,
seq_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
list
[
SequenceGroupMetadata
]:
if
continuations
is
None
:
continuations
=
[[]
for
_
in
prompts
]
if
seq_ids
is
None
:
seq_ids
=
list
(
i
for
i
,
_
in
enumerate
(
prompts
))
free_gpu_blocks
=
list
(
range
(
num_gpu_blocks
))
block_allocations
=
{
i
:
[
free_gpu_blocks
.
pop
()
for
_
in
range
(
round_up_to_next_block
(
final_len
,
block_size
))
]
for
i
,
final_len
in
enumerate
(
final_prompt_lens
)
}
seq_grou_metadata_list
=
[]
for
i
,
(
prompt_token_ids
,
cont_token_ids
)
in
enumerate
(
zip
(
prompts
,
continuations
)):
data
=
SequenceData
.
from_seqs
(
prompt_token_ids
,
cont_token_ids
)
data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
cont_token_ids
)
-
1
)
seq_data
=
{
i
:
data
}
seq_grou_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
str
(
i
),
is_prompt
=
len
(
cont_token_ids
)
==
0
,
seq_data
=
seq_data
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
),
block_tables
=
{
i
:
block_allocations
[
i
][:]},
))
return
seq_grou_metadata_list
def
create_chunked_seq_group_metadata_from_prompt
(
prompt
:
list
[
int
],
num_gpu_blocks
:
int
,
chunk_size
:
int
,
block_size
:
int
,
seq_id
:
Optional
[
int
]
=
None
)
->
list
[
SequenceGroupMetadata
]:
if
seq_id
is
None
:
seq_id
=
0
free_gpu_blocks
=
list
(
range
(
num_gpu_blocks
))
block_allocations
=
[
free_gpu_blocks
.
pop
()
for
_
in
range
(
round_up_to_next_block
(
len
(
prompt
),
block_size
))
]
seq_group_metadata_list
=
[]
for
i
,
idx
in
enumerate
(
range
(
0
,
len
(
prompt
),
chunk_size
)):
chunk_ids
=
prompt
[
idx
:
idx
+
chunk_size
]
data
=
SequenceData
.
from_seqs
(
prompt
)
data
.
update_num_computed_tokens
(
idx
)
seq_data
=
{
i
:
data
}
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
str
(
seq_id
),
is_prompt
=
True
,
do_sample
=
idx
+
chunk_size
>=
len
(
prompt
),
# terminal chunk
seq_data
=
seq_data
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
),
block_tables
=
{
i
:
block_allocations
},
token_chunk_size
=
len
(
chunk_ids
)))
return
seq_group_metadata_list
def
assert_logprobs_dict_allclose
(
actual_logprobs
:
list
[
dict
[
int
,
Logprob
]],
expected_logprobs
:
list
[
dict
[
int
,
Logprob
]])
->
None
:
for
single_step_actual_logprobs
,
single_step_expected_logprobs
in
zip
(
actual_logprobs
,
expected_logprobs
):
assert
set
(
single_step_actual_logprobs
.
keys
())
==
set
(
single_step_expected_logprobs
.
keys
())
for
token_id
in
single_step_actual_logprobs
:
actual
=
torch
.
tensor
(
single_step_actual_logprobs
[
token_id
].
logprob
)
expected
=
torch
.
tensor
(
single_step_expected_logprobs
[
token_id
].
logprob
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
probs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
logprobs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
list
[
int
]]
=
None
)
->
list
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
if
seq_ids
is
None
:
seq_ids
=
list
(
range
(
batch_size
))
return
[
SamplerOutput
(
outputs
=
[
CompletionSequenceGroupOutput
(
samples
=
[
SequenceOutput
(
output_token
=
token_id
,
parent_seq_id
=
seq_ids
[
seq_index
],
logprobs
=
{
token_id
:
Logprob
(
0
)},
)
],
prompt_logprobs
=
None
,
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
],
sampled_token_probs
=
probs
[
step
],
logprobs
=
logprobs
[
step
],
sampled_token_ids
=
token_ids
[
step
])
for
step
in
range
(
num_steps
)
]
def
create_batch
(
batch_size
,
k
,
prompt_len
:
Union
[
int
,
list
[
int
]]
=
10
,
prev_output_token_len
:
int
=
10
,
seq_ids
:
Optional
[
list
[
int
]]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
,
prefill_chunk_size
:
Optional
[
int
]
=
None
):
if
block_size
is
None
:
block_size
=
8
if
num_gpu_blocks
is
None
:
num_gpu_blocks
=
2048
//
block_size
iterator
=
count
()
if
isinstance
(
prompt_len
,
int
):
prompt_lens
=
[
prompt_len
for
_
in
range
(
batch_size
)]
else
:
prompt_lens
=
prompt_len
prompts
=
[[
next
(
iterator
)
for
_
in
range
(
p_len
)]
for
p_len
in
prompt_lens
]
if
prefill_chunk_size
:
# Create a batch of chunked prompts.
if
not
seq_ids
:
seq_ids
=
list
(
range
(
len
(
prompts
)))
seq_group_metadata_list
=
[]
for
p
,
sid
in
zip
(
prompts
,
seq_ids
):
seq_group_metadata_list
+=
\
create_chunked_seq_group_metadata_from_prompt
(
p
,
num_gpu_blocks
,
prefill_chunk_size
,
block_size
,
sid
)
seq_group_metadata_list
=
seq_group_metadata_list
[:
batch_size
]
prev_output_tokens
=
[]
else
:
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
final_prompt_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
prev_output_tokens
,
seq_ids
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
def
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
llm_kwargs
):
if
prefill_chunk_size
>
0
:
llm_kwargs
.
update
(
**
{
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
prefill_chunk_size
,
"max_num_seqs"
:
prefill_chunk_size
})
else
:
llm_kwargs
[
"enable_chunked_prefill"
]
=
False
tests/test_sequence.py
View file @
dd572c0a
...
...
@@ -29,7 +29,6 @@ def test_sampler_output_initialization(sampler_output, sample_outputs):
assert
len
(
sampler_output
)
==
len
(
sample_outputs
)
assert
sampler_output
.
sampled_token_probs
is
None
assert
sampler_output
.
sampled_token_ids
is
None
assert
sampler_output
.
spec_decode_worker_metrics
is
None
def
test_sampler_output_getitem
(
sampler_output
,
sample_outputs
):
...
...
tests/v1/test_oracle.py
View file @
dd572c0a
...
...
@@ -40,12 +40,6 @@ def test_unsupported_configs(monkeypatch):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
kv_cache_dtype
=
"fp8"
,
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
...
...
tools/mypy.sh
View file @
dd572c0a
...
...
@@ -32,6 +32,5 @@ run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment