Unverified Commit 50c9636d authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[V1][Usage] Refactor speculative decoding configuration and tests (#14434)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent 0661cfef
......@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
speculative_config={
"model": "facebook/opt-125m",
"num_speculative_tokens": 5,
},
)
outputs = llm.generate(prompts, sampling_params)
......@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:
```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
--seed 42 -tp 1 --speculative_model facebook/opt-125m \
--num_speculative_tokens 5 --gpu_memory_utilization 0.8
--seed 42 -tp 1 --gpu_memory_utilization 0.8 \
--speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
```
:::{warning}
Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
:::
Then use a client:
```python
......@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
speculative_config={
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 4,
},
)
outputs = llm.generate(prompts, sampling_params)
......@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
speculative_model="ibm-ai-platform/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
speculative_config={
"model": "ibm-ai-platform/llama3-70b-accelerator",
"draft_tensor_parallel_size": 1,
},
)
outputs = llm.generate(prompts, sampling_params)
......@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
tensor_parallel_size=4,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
speculative_draft_tensor_parallel_size=1,
speculative_config={
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"draft_tensor_parallel_size": 1,
},
)
outputs = llm.generate(prompts, sampling_params)
......@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
[script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using
the latest version of vLLM, please leave a comment or raise an issue.
and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.
2. The EAGLE based draft models need to be run without tensor parallelism
(i.e. speculative_draft_tensor_parallel_size is set to 1), although
(i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
it is possible to run the main model using tensor parallelism (see example above).
3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
......
......@@ -50,7 +50,9 @@ if __name__ == "__main__":
# Create an LLM with spec decoding
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-ai-platform/llama-13b-accelerator",
speculative_config={
"model": "ibm-ai-platform/llama-13b-accelerator",
},
)
print("With speculation")
......
......@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
......
......@@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-3.2-1B-Instruct",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-3.2-1B-Instruct",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len": 131072 + 1,
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
......
......@@ -57,8 +57,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"disable_logprobs": False,
},
{
"speculative_model": SPEC_MODEL,
}, {
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"disable_logprobs": True,
},
])
}])
@pytest.mark.parametrize("output_len", [
128,
])
......@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
......@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......
......@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_model_quantization": "gptq",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "gptq",
},
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_model_quantization": "marlin",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "marlin",
},
},
# Not explicitly specify draft model quantization
{
"speculative_model_quantization": None,
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": None,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
......
......@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
"--speculative_config",
str({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
......@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative-model",
"ibm-granite/granite-3b-code-instruct",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize(
"model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative_config",
str({
"model": "ibm-granite/granite-3b-code-instruct",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
......@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
......
......@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
[],
])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
......@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[
#TODO(wooyeon): add spec_draft_dp=2 case
[
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
......@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs",
[
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len",
"32",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,
}),
],
])
@pytest.mark.parametrize("batch_size", [8])
......
......@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
......@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
......@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
# Artificially limit the draft model max model len; this forces
# vLLM to skip speculation once the sequences grow beyond 32-k
# tokens.
"max_model_len": 32,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
......@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
......@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
......@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
......@@ -60,8 +60,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
......@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......
......@@ -62,7 +62,9 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
......@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
......@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
......@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
......@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
......@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
......@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......
......@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
......@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"num_speculative_tokens": k,
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
......@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......
......@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_logprobs": False,
},
"enable_chunked_prefill": False,
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
}])
@pytest.mark.parametrize(
"output_len",
[
......@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted.
"""
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model")
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
......@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
......@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
......@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
......@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
......
......@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
......@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
......@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
......@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
......@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
......@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
}, {
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
......@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
......
......@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3,
# speculative config
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
......
......@@ -70,12 +70,16 @@ def test_ngram_correctness(
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(model=model_name,
speculative_model='[ngram]',
ngram_prompt_lookup_max=5,
ngram_prompt_lookup_min=3,
num_speculative_tokens=3,
max_model_len=1024)
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
......
This diff is collapsed.
......@@ -177,7 +177,10 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
......@@ -190,9 +193,9 @@ class EngineArgs:
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
......@@ -780,7 +783,11 @@ class EngineArgs:
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
......@@ -1192,6 +1199,82 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load,
)
def create_speculative_config(
self,
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
enable_chunked_prefill: bool,
disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
"""Initializes and returns a SpeculativeConfig object based on
`speculative_config`.
This function utilizes `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine. If `speculative_config` is not set, this
function will attempt to construct a configuration dictionary using
certain parameters, which are scheduled for deprecation in the next
release. Note that in next releases, `speculative_config` must be
provided, and the deprecated standalone speculative-related parameters
will be removed.
"""
if self.speculative_config is None:
if (self.speculative_model is None
and self.num_speculative_tokens is None):
return None
# TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
# only allow '--speculative-config' after next release
logger.warning_once(
"Please use '--speculative-config' to set all configurations "
"related to speculative decoding. The current method of "
"specifying the model through '--speculative-model' and "
"adding related parameters (e.g., '--num-speculative-tokens') "
"separately will be deprecated in the next release.")
spec_config_dict = {
"model": self.speculative_model,
"quantization": self.speculative_model_quantization,
"max_model_len": self.speculative_max_model_len,
"draft_tensor_parallel_size":
self.speculative_draft_tensor_parallel_size,
"num_speculative_tokens": self.num_speculative_tokens,
"disable_mqa_scorer": self.speculative_disable_mqa_scorer,
"disable_by_batch_size":
self.speculative_disable_by_batch_size,
"prompt_lookup_max": self.ngram_prompt_lookup_max,
"prompt_lookup_min": self.ngram_prompt_lookup_min,
"acceptance_method": self.spec_decoding_acceptance_method,
"posterior_threshold":
self.typical_acceptance_sampler_posterior_threshold,
"posterior_alpha":
self.typical_acceptance_sampler_posterior_alpha,
"disable_logprobs": self.disable_logprobs_during_spec_decoding,
}
self.speculative_config = spec_config_dict
else:
if isinstance(self.speculative_config, str):
import ast
self.speculative_config = ast.literal_eval(
self.speculative_config)
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
# config.
assert isinstance(self.speculative_config, dict)
self.speculative_config.update({
"target_model_config": target_model_config,
"target_parallel_config": target_parallel_config,
"enable_chunked_prefill": enable_chunked_prefill,
"disable_log_stats": disable_log_stats,
})
speculative_config = SpeculativeConfig.from_dict(
self.speculative_config)
return speculative_config
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
......@@ -1238,6 +1321,8 @@ class EngineArgs:
else:
self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
......@@ -1280,31 +1365,11 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
)
speculative_config = SpeculativeConfig.maybe_create_spec_config(
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
# Reminder: Please update docs/source/features/compatibility_matrix.md
......@@ -1569,7 +1634,7 @@ class EngineArgs:
if (self.speculative_model is not None
or self.num_speculative_tokens is not None):
# This is supported but experimental (handled below).
if self.speculative_model == "[ngram]":
if self.speculative_model in ("ngram", "[ngram]"):
pass
else:
_raise_or_fallback(feature_name="Speculative Decoding",
......@@ -1617,7 +1682,8 @@ class EngineArgs:
return False
# ngram is supported on V1, but off by default for now.
if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
return False
# Non-CUDA is supported on V1, but off by default for now.
......
......@@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
# Override draft-model specific worker args.
draft_worker_kwargs.update(
vllm_config=draft_worker_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
draft_token_acceptance_method=speculative_config.
draft_token_acceptance_method,
disable_mqa_scorer=speculative_config.disable_mqa_scorer,
disable_by_batch_size=speculative_config.disable_by_batch_size,
draft_token_acceptance_method=speculative_config.acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha,
posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
num_speculative_tokens=speculative_config.num_speculative_tokens,
......
......@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
assert self.speculative_config.method == "ngram", \
"Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
......@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
self.rejection_sampler = RejectionSampler()
......@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)
if drafter_output is None or len(drafter_output) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment