Unverified commit 50c9636d, authored by shangmingc, committed by GitHub
Browse files

[V1][Usage] Refactor speculative decoding configuration and tests (#14434)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 0661cfef
...@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ...@@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM( llm = LLM(
model="facebook/opt-6.7b", model="facebook/opt-6.7b",
tensor_parallel_size=1, tensor_parallel_size=1,
speculative_model="facebook/opt-125m", speculative_config={
num_speculative_tokens=5, "model": "facebook/opt-125m",
"num_speculative_tokens": 5,
},
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
...@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server: ...@@ -45,10 +47,14 @@ To perform the same with an online mode launch the server:
```bash ```bash
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
--seed 42 -tp 1 --speculative_model facebook/opt-125m \ --seed 42 -tp 1 --gpu_memory_utilization 0.8 \
--num_speculative_tokens 5 --gpu_memory_utilization 0.8 --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'
``` ```
:::{warning}
Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
:::
Then use a client: Then use a client:
```python ```python
...@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ...@@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM( llm = LLM(
model="facebook/opt-6.7b", model="facebook/opt-6.7b",
tensor_parallel_size=1, tensor_parallel_size=1,
speculative_model="[ngram]", speculative_config={
num_speculative_tokens=5, "method": "ngram",
ngram_prompt_lookup_max=4, "num_speculative_tokens": 5,
"prompt_lookup_max": 4,
},
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
...@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ...@@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM( llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct", model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4, tensor_parallel_size=4,
speculative_model="ibm-ai-platform/llama3-70b-accelerator", speculative_config={
speculative_draft_tensor_parallel_size=1, "model": "ibm-ai-platform/llama3-70b-accelerator",
"draft_tensor_parallel_size": 1,
},
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
...@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ...@@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM( llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
tensor_parallel_size=4, tensor_parallel_size=4,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", speculative_config={
speculative_draft_tensor_parallel_size=1, "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"draft_tensor_parallel_size": 1,
},
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
...@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models: ...@@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models:
be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304).
If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the
[script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model,
and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue.
the latest version of vLLM, please leave a comment or raise an issue.
2. The EAGLE based draft models need to be run without tensor parallelism 2. The EAGLE based draft models need to be run without tensor parallelism
(i.e. speculative_draft_tensor_parallel_size is set to 1), although (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although
it is possible to run the main model using tensor parallelism (see example above). it is possible to run the main model using tensor parallelism (see example above).
3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is
......
...@@ -50,7 +50,9 @@ if __name__ == "__main__": ...@@ -50,7 +50,9 @@ if __name__ == "__main__":
# Create an LLM with spec decoding # Create an LLM with spec decoding
llm = LLM( llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf", model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-ai-platform/llama-13b-accelerator", speculative_config={
"model": "ibm-ai-platform/llama-13b-accelerator",
},
) )
print("With speculation") print("With speculation")
......
...@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def maybe_assert_ngram_worker(llm): def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified. # Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance( assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker, llm.llm_engine.model_executor.driver_worker.proposer_worker,
......
...@@ -7,28 +7,39 @@ from vllm import SamplingParams ...@@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{ @pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-3.2-1B-Instruct", "model": "meta-llama/Llama-3.2-1B-Instruct",
"speculative_model": "JackFram/llama-68m", }])
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
[ [
{ {
# Speculative max model len > overridden max model len should raise. # Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128, "max_model_len": 128,
"speculative_max_model_len": 129,
}, },
{ {
# Speculative max model len > draft max model len should raise. # Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
}, },
{ {
# Speculative max model len > target max model len should raise. # Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len": 131072 + 1, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
}, },
]) ])
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
......
...@@ -57,9 +57,11 @@ PRECISION = "float32" ...@@ -57,9 +57,11 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [{
{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, },
{ }, {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
}, },
]) }])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
]) ])
...@@ -119,7 +122,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -119,7 +122,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -129,8 +133,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -129,8 +133,8 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
seed, seed,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -151,9 +155,11 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -151,9 +155,11 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -193,9 +199,11 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( ...@@ -193,9 +199,11 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption( ...@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, ...@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -324,9 +335,11 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -324,9 +335,11 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B", "speculative_config": {
"model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -372,9 +385,11 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -372,9 +385,11 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "speculative_config": {
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -420,9 +435,11 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -420,9 +435,11 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct", "speculative_config": {
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -23,9 +23,11 @@ MAIN_MODEL = "JackFram/llama-68m" ...@@ -23,9 +23,11 @@ MAIN_MODEL = "JackFram/llama-68m"
[ [
{ {
# Identical models. # Identical models.
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
}, },
},
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}])
...@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, ...@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_llm_kwargs", "test_llm_kwargs",
[ [
# Explicitly specify draft model quantization # Explicitly specify draft model quantization
{ {
"speculative_model_quantization": "gptq", "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "gptq",
},
}, },
# Explicitly specify GPTQ-based draft model to use marlin quantization # Explicitly specify GPTQ-based draft model to use marlin quantization
{ {
"speculative_model_quantization": "marlin", "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "marlin",
},
}, },
# Not explicitly specify draft model quantization # Not explicitly specify draft model quantization
{ {
"speculative_model_quantization": None, "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": None,
},
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": "JackFram/llama-68m",
}]) "num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int): output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output """Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer. with batch expansion scorer and mqa scorer.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
......
...@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp ...@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
[ [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num-speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
}),
], ],
[ [
"--speculative-model", "--speculative_config",
"[ngram]", str({
"--num-speculative-tokens", "model": "ngram",
"5", "num_speculative_tokens": 5,
"--ngram-prompt-lookup-max", "prompt_lookup_max": 3,
"3", }),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
...@@ -83,22 +84,23 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -83,22 +84,23 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]]) ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs", @pytest.mark.parametrize(
"model, test_llm_kwargs",
[("JackFram/llama-68m", [ [("JackFram/llama-68m", [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num_speculative-tokens", "model": "JackFram/llama-68m",
"5", "num_speculative_tokens": 5,
"--speculative-draft-tensor-parallel-size", "draft_tensor_parallel_size": 1,
"1", }),
]), ]),
("ibm-granite/granite-3b-code-instruct", [ ("ibm-granite/granite-3b-code-instruct", [
"--speculative-model", "--speculative_config",
"ibm-granite/granite-3b-code-instruct", str({
"--num_speculative-tokens", "model": "ibm-granite/granite-3b-code-instruct",
"5", "num_speculative_tokens": 5,
"--speculative-draft-tensor-parallel-size", "draft_tensor_parallel_size": 1,
"1", }),
])]) ])])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
...@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, ...@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs", @pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [ [("JackFram/llama-68m", [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num_speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
}),
]), ]),
("JackFram/llama-68m", [ ("JackFram/llama-68m", [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num_speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
"--speculative-draft-tensor-parallel-size", "draft_tensor_parallel_size": 1,
"1", }),
])]) ])])
@pytest.mark.parametrize("logprobs", [None, 2]) @pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
......
...@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m" ...@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4", "4",
]]) ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
[ [],
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m" ...@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[ [
#TODO(wooyeon): add spec_draft_dp=2 case #TODO(wooyeon): add spec_draft_dp=2 case
[ [
"--speculative-draft-tensor-parallel-size", "--speculative_config",
"1", str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
...@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, ...@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
[ [
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM # Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens. # to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len", "--speculative_config",
"32", str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,
}),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
......
...@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test ...@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, { },
"speculative_model": "JackFram/llama-68m", }, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
}]) },
}])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -48,7 +51,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -48,7 +51,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill. as well as with and without chunked prefill.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -59,8 +63,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -59,8 +63,8 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-160m", "model": "JackFram/llama-160m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, { },
"speculative_model": "JackFram/llama-160m", }, {
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 6, "num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}]) },
}])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -98,7 +105,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -98,7 +105,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int): output_len: int, seed: int, logprobs: int):
"""Verify logprob greedy equality with different speculation lens. """Verify logprob greedy equality with different speculation lens.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -108,8 +116,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -108,8 +116,8 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
seed, seed,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_llm_kwargs", "test_llm_kwargs",
[{ [{
"speculative_model": "JackFram/llama-160m", "speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
# Artificially limit the draft model max model len; this forces
# Artificially limit the draft model max model len; this forces vLLM # vLLM to skip speculation once the sequences grow beyond 32-k
# to skip speculation once the sequences grow beyond 32-k tokens. # tokens.
"speculative_max_model_len": 32, "max_model_len": 32,
},
}]) }])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -149,7 +159,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -149,7 +159,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int): seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation. """Verify logprobs greedy equality when some sequences skip speculation.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -159,8 +170,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -159,8 +170,8 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed, seed,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-160m", "model": "JackFram/llama-160m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, ...@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
}]) },
}])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -270,7 +283,8 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, ...@@ -270,7 +283,8 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled. """Check the behavior when logprobs are disabled.
Token choices should match with the base model. Token choices should match with the base model.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -280,5 +294,5 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, ...@@ -280,5 +294,5 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
seed, seed,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
...@@ -60,9 +60,11 @@ PRECISION = "float32" ...@@ -60,9 +60,11 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
},
}, },
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -132,7 +138,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -132,7 +138,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size: int): prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -143,8 +150,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -143,8 +150,8 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -165,9 +172,11 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -165,9 +172,11 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -214,9 +223,11 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ...@@ -214,9 +223,11 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption( ...@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, ...@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": SPEC_MODEL,
}]) "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -62,7 +62,9 @@ PRECISION = "float32" ...@@ -62,7 +62,9 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"disable_logprobs_during_spec_decoding": False, "model": SPEC_MODEL,
"disable_logprobs": False,
},
}, },
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"disable_logprobs_during_spec_decoding": True, "model": SPEC_MODEL,
"disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [8]) @pytest.mark.parametrize("output_len", [8])
...@@ -133,7 +139,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -133,7 +139,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change). # up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -144,8 +151,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -144,8 +151,8 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [2048]) @pytest.mark.parametrize("output_len", [2048])
...@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, ...@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# Speculative model # Speculative config
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
...@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, ...@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption( ...@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding( ...@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, ...@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
"speculative_model": SPEC_MODEL,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": SPEC_MODEL,
}]) "disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -57,8 +57,10 @@ PRECISION = "bfloat16" ...@@ -57,8 +57,10 @@ PRECISION = "bfloat16"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
},
}, },
{ {
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -119,7 +125,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -119,7 +125,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -129,8 +136,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -129,8 +136,8 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
seed, seed,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -152,8 +159,10 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -152,8 +159,10 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
...@@ -198,8 +207,10 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, ...@@ -198,8 +207,10 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
}, },
},
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption( ...@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_config": {
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, ...@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator, ...@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
# Chunked prefill enabled with small value # Chunked prefill enabled with small value
# to make sure we get mixed batches. # to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, ...@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"disable_logprobs": False,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False }, {
}, { "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs": False,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False }])
}])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
[ [
...@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted. whether all speculative tokens are accepted.
""" """
ensure_all_accepted = per_test_common_llm_kwargs.get( ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model") "model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
...@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( ...@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( ...@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( ...@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( ...@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( ...@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, ...@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM # Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens. # to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
"speculative_max_model_len": 32,
}, },
]) ])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
...@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2, "disable_by_batch_size": 2,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2, "disable_by_batch_size": 2,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
...@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, ...@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{ ] + [{
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k, "num_speculative_tokens": k,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
...@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k, "num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler", "acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False "enable_chunked_prefill": False
} }
# Try a range of common k. # Try a range of common k.
for k in [1, 2, 3] for k in [1, 2, 3]
] + [{ ] + [{
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k, "num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler", "acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
......
...@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test ...@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False, "disable_mqa_scorer": False,
},
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True, "disable_mqa_scorer": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
},
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -125,7 +133,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -125,7 +133,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
"""Verify greedy equality on a tiny model with different batch size.""" """Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, baseline_llm_kwargs,
...@@ -136,8 +145,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -136,8 +145,8 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
temperature=0.0, temperature=0.0,
logprobs=logprobs, logprobs=logprobs,
prompt_logprobs=logprobs, prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ disable_logprobs=test_llm_kwargs["speculative_config"]
'disable_logprobs_during_spec_decoding']) ["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
}, },
...@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ...@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": k, "num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
},
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 3, 5] for k in [1, 3, 5]
] + [ ] + [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"method": "ngram",
"num_speculative_tokens": k, "num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1, "prompt_lookup_max": 1,
},
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 3, 5] for k in [1, 3, 5]
...@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed: int): seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram prompt_lookup_max.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
...@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "[ngram]", "method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4
}, { },
"speculative_model": "[ngram]", }, {
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4, "disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
}]) }])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed: int): seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram prompt_lookup_max.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
...@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "method": "ngram",
}]) "num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m" ...@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# speculative model # speculative config
"speculative_model": "JackFram/llama-160m", "speculative_config": {
"model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
},
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
......
...@@ -70,12 +70,16 @@ def test_ngram_correctness( ...@@ -70,12 +70,16 @@ def test_ngram_correctness(
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
spec_llm = LLM(model=model_name, spec_llm = LLM(
speculative_model='[ngram]', model=model_name,
ngram_prompt_lookup_max=5, speculative_config={
ngram_prompt_lookup_min=3, "method": "ngram",
num_speculative_tokens=3, "prompt_lookup_max": 5,
max_model_len=1024) "prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
misses = 0 misses = 0
......
This diff is collapsed.
...@@ -177,7 +177,10 @@ class EngineArgs: ...@@ -177,7 +177,10 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar' guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None speculative_draft_tensor_parallel_size: Optional[int] = None
...@@ -190,9 +193,9 @@ class EngineArgs: ...@@ -190,9 +193,9 @@ class EngineArgs:
spec_decoding_acceptance_method: str = 'rejection_sampler' spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None disable_logprobs_during_spec_decoding: Optional[bool] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
...@@ -780,7 +783,11 @@ class EngineArgs: ...@@ -780,7 +783,11 @@ class EngineArgs:
const="True", const="True",
help='If set, the prefill requests can be chunked based on the ' help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.') 'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument( parser.add_argument(
'--speculative-model', '--speculative-model',
type=nullable_str, type=nullable_str,
...@@ -1192,6 +1199,82 @@ class EngineArgs: ...@@ -1192,6 +1199,82 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load, use_tqdm_on_load=self.use_tqdm_on_load,
) )
def create_speculative_config(
    self,
    target_model_config: ModelConfig,
    target_parallel_config: ParallelConfig,
    enable_chunked_prefill: bool,
    disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
    """Build a SpeculativeConfig from `speculative_config`, if any.

    `speculative_config` may be a JSON string (as received from the
    `--speculative-config` CLI argument) or a dictionary passed directly
    when constructing the engine. If it is not set, a dictionary is
    assembled from the standalone speculative-related parameters
    (`speculative_model`, `num_speculative_tokens`, ...), which are
    scheduled for deprecation in the next release; after that,
    `speculative_config` must be provided.

    Args:
        target_model_config: Model config of the target model; forwarded
            to `SpeculativeConfig.from_dict` under the same key.
        target_parallel_config: Parallel config of the target model;
            forwarded under the same key.
        enable_chunked_prefill: Forwarded under the same key.
        disable_log_stats: Forwarded under the same key.

    Returns:
        The result of `SpeculativeConfig.from_dict`, or None when neither
        `speculative_config` nor the legacy `speculative_model` /
        `num_speculative_tokens` parameters are set.

    Raises:
        ValueError, SyntaxError: If a string `speculative_config` is
            neither valid JSON nor a valid Python literal.
        AssertionError: If the parsed configuration is not a dictionary.
    """
    if self.speculative_config is None:
        if (self.speculative_model is None
                and self.num_speculative_tokens is None):
            # Speculative decoding was not requested at all.
            return None

        # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
        # only allow '--speculative-config' after next release
        logger.warning_once(
            "Please use '--speculative-config' to set all configurations "
            "related to speculative decoding. The current method of "
            "specifying the model through '--speculative-model' and "
            "adding related parameters (e.g., '--num-speculative-tokens') "
            "separately will be deprecated in the next release.")

        # Map the legacy standalone arguments onto the new config keys.
        spec_config_dict = {
            "model": self.speculative_model,
            "quantization": self.speculative_model_quantization,
            "max_model_len": self.speculative_max_model_len,
            "draft_tensor_parallel_size":
            self.speculative_draft_tensor_parallel_size,
            "num_speculative_tokens": self.num_speculative_tokens,
            "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
            "disable_by_batch_size":
            self.speculative_disable_by_batch_size,
            "prompt_lookup_max": self.ngram_prompt_lookup_max,
            "prompt_lookup_min": self.ngram_prompt_lookup_min,
            "acceptance_method": self.spec_decoding_acceptance_method,
            "posterior_threshold":
            self.typical_acceptance_sampler_posterior_threshold,
            "posterior_alpha":
            self.typical_acceptance_sampler_posterior_alpha,
            "disable_logprobs": self.disable_logprobs_during_spec_decoding,
        }
    elif isinstance(self.speculative_config, str):
        import ast
        import json

        try:
            # The CLI documents this value as a JSON string, so JSON
            # literals such as true/false/null must be accepted.
            spec_config_dict = json.loads(self.speculative_config)
        except json.JSONDecodeError:
            # Backward compatibility: Python-literal dict strings (e.g.
            # single-quoted keys, True/False/None) were accepted before.
            spec_config_dict = ast.literal_eval(self.speculative_config)
    else:
        spec_config_dict = self.speculative_config

    assert isinstance(spec_config_dict, dict)

    # Note(Shangming): These parameters are not obtained from the cli arg
    # '--speculative-config' and must be passed in when creating the engine
    # config.
    # Merge into a fresh dict so that a dictionary supplied by the caller
    # is not mutated in place.
    self.speculative_config = {
        **spec_config_dict,
        "target_model_config": target_model_config,
        "target_parallel_config": target_parallel_config,
        "enable_chunked_prefill": enable_chunked_prefill,
        "disable_log_stats": disable_log_stats,
    }

    return SpeculativeConfig.from_dict(self.speculative_config)
def create_engine_config( def create_engine_config(
self, self,
usage_context: Optional[UsageContext] = None, usage_context: Optional[UsageContext] = None,
...@@ -1238,6 +1321,8 @@ class EngineArgs: ...@@ -1238,6 +1321,8 @@ class EngineArgs:
else: else:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
...@@ -1280,31 +1365,11 @@ class EngineArgs: ...@@ -1280,31 +1365,11 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
) )
speculative_config = SpeculativeConfig.maybe_create_spec_config( speculative_config = self.create_speculative_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats, disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
) )
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
...@@ -1569,7 +1634,7 @@ class EngineArgs: ...@@ -1569,7 +1634,7 @@ class EngineArgs:
if (self.speculative_model is not None if (self.speculative_model is not None
or self.num_speculative_tokens is not None): or self.num_speculative_tokens is not None):
# This is supported but experimental (handled below). # This is supported but experimental (handled below).
if self.speculative_model == "[ngram]": if self.speculative_model in ("ngram", "[ngram]"):
pass pass
else: else:
_raise_or_fallback(feature_name="Speculative Decoding", _raise_or_fallback(feature_name="Speculative Decoding",
...@@ -1617,7 +1682,8 @@ class EngineArgs: ...@@ -1617,7 +1682,8 @@ class EngineArgs:
return False return False
# ngram is supported on V1, but off by default for now. # ngram is supported on V1, but off by default for now.
if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"): if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
return False return False
# Non-CUDA is supported on V1, but off by default for now. # Non-CUDA is supported on V1, but off by default for now.
......
...@@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
# Override draft-model specific worker args. # Override draft-model specific worker args.
draft_worker_kwargs.update( draft_worker_kwargs.update(
vllm_config=draft_worker_config, vllm_config=draft_worker_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
) )
spec_decode_worker = SpecDecodeWorker.create_worker( spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker, scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs, draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, disable_mqa_scorer=speculative_config.disable_mqa_scorer,
disable_by_batch_size=speculative_config. disable_by_batch_size=speculative_config.disable_by_batch_size,
speculative_disable_by_batch_size, draft_token_acceptance_method=speculative_config.acceptance_method,
draft_token_acceptance_method=speculative_config.
draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold, posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config. typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha, posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs, disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats, disable_log_stats=speculative_config.disable_log_stats,
num_speculative_tokens=speculative_config.num_speculative_tokens, num_speculative_tokens=speculative_config.num_speculative_tokens,
......
...@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -151,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
# TODO: find a better way to check if we are using ngram. assert self.speculative_config.method == "ngram", \
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.drafter = NgramProposer() self.drafter = NgramProposer()
...@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -160,7 +159,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second. # This usually takes less than 1 second.
self.drafter.propose( self.drafter.propose(
np.zeros(1024, dtype=np.int32), np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens, self.speculative_config.num_speculative_tokens,
) )
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
...@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1155,7 +1154,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose( drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx], self.input_batch.token_ids_cpu[i, :end_idx],
self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.prompt_lookup_min,
self.speculative_config.num_speculative_tokens, self.speculative_config.num_speculative_tokens,
) )
if drafter_output is None or len(drafter_output) == 0: if drafter_output is None or len(drafter_output) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment