Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m" ...@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
[ [
{ {
# Identical models. # Identical models.
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, ...@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_llm_kwargs", "test_llm_kwargs",
[ [
# Explicitly specify draft model quantization # Explicitly specify draft model quantization
{ {
"speculative_model_quantization": "gptq", "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "gptq",
},
}, },
# Explicitly specify GPTQ-based draft model to use marlin quantization # Explicitly specify GPTQ-based draft model to use marlin quantization
{ {
"speculative_model_quantization": "marlin", "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "marlin",
},
}, },
# Not explicitly specify draft model quantization # Not explicitly specify draft model quantization
{ {
"speculative_model_quantization": None, "speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": None,
},
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": "JackFram/llama-68m",
}]) "num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int): output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output """Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer. with batch expansion scorer and mqa scorer.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
......
...@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp ...@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
[ [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num-speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
}),
], ],
[ [
"--speculative-model", "--speculative_config",
"[ngram]", str({
"--num-speculative-tokens", "model": "ngram",
"5", "num_speculative_tokens": 5,
"--ngram-prompt-lookup-max", "prompt_lookup_max": 3,
"3", }),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
...@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]]) ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs", @pytest.mark.parametrize(
[("JackFram/llama-68m", [ "model, test_llm_kwargs",
"--speculative-model", [("JackFram/llama-68m", [
"JackFram/llama-68m", "--speculative_config",
"--num_speculative-tokens", str({
"5", "model": "JackFram/llama-68m",
"--speculative-draft-tensor-parallel-size", "num_speculative_tokens": 5,
"1", "draft_tensor_parallel_size": 1,
]), }),
("ibm-granite/granite-3b-code-instruct", [ ]),
"--speculative-model", ("ibm-granite/granite-3b-code-instruct", [
"ibm-granite/granite-3b-code-instruct", "--speculative_config",
"--num_speculative-tokens", str({
"5", "model": "ibm-granite/granite-3b-code-instruct",
"--speculative-draft-tensor-parallel-size", "num_speculative_tokens": 5,
"1", "draft_tensor_parallel_size": 1,
])]) }),
])])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
...@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, ...@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs", @pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [ [("JackFram/llama-68m", [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num_speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
}),
]), ]),
("JackFram/llama-68m", [ ("JackFram/llama-68m", [
"--speculative-model", "--speculative_config",
"JackFram/llama-68m", str({
"--num_speculative-tokens", "model": "JackFram/llama-68m",
"3", "num_speculative_tokens": 3,
"--speculative-draft-tensor-parallel-size", "draft_tensor_parallel_size": 1,
"1", }),
])]) ])])
@pytest.mark.parametrize("logprobs", [None, 2]) @pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
......
...@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m" ...@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4", "4",
]]) ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
[ [],
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m" ...@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[ [
#TODO(wooyeon): add spec_draft_dp=2 case #TODO(wooyeon): add spec_draft_dp=2 case
[ [
"--speculative-draft-tensor-parallel-size", "--speculative_config",
"1", str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("batch_size", [2])
...@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, ...@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
[ [
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM # Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens. # to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len", "--speculative_config",
"32", str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,
}),
], ],
]) ])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
......
...@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test ...@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, { },
"speculative_model": "JackFram/llama-68m", }, {
"num_speculative_tokens": 3, "speculative_config": {
"disable_logprobs_during_spec_decoding": True, "model": "JackFram/llama-68m",
}]) "num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill. as well as with and without chunked prefill.
""" """
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
temperature=0.0, seed,
logprobs=logprobs, temperature=0.0,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-160m", "model": "JackFram/llama-160m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}, { },
"speculative_model": "JackFram/llama-160m", }, {
"num_speculative_tokens": 6, "speculative_config": {
"disable_logprobs_during_spec_decoding": False, "model": "JackFram/llama-160m",
}]) "num_speculative_tokens": 6,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int): output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens. """Veriy logprob greedy equality with different speculation lens.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
temperature=0.0, seed,
logprobs=logprobs, temperature=0.0,
disable_logprobs=test_llm_kwargs[ logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_llm_kwargs", "test_llm_kwargs",
[{ [{
"speculative_model": "JackFram/llama-160m", "speculative_config": {
"num_speculative_tokens": 3, "model": "JackFram/llama-160m",
"disable_logprobs_during_spec_decoding": False, "num_speculative_tokens": 3,
"disable_logprobs": False,
# Artificially limit the draft model max model len; this forces vLLM # Artificially limit the draft model max model len; this forces
# to skip speculation once the sequences grow beyond 32-k tokens. # vLLM to skip speculation once the sequences grow beyond 32-k
"speculative_max_model_len": 32, # tokens.
"max_model_len": 32,
},
}]) }])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int): seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation. """Verify logprobs greedy equality when some sequences skip speculation.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
temperature=0.0, seed,
logprobs=logprobs, temperature=0.0,
disable_logprobs=test_llm_kwargs[ logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-160m", "model": "JackFram/llama-160m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False, "disable_logprobs": False,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, ...@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 3, "num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True, "disable_logprobs": True,
}]) },
}])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, ...@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled. """Check the behavior when logprobs are disabled.
Token choices should match with the base model. Token choices should match with the base model.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
temperature=0.0, seed,
logprobs=logprobs, temperature=0.0,
disable_logprobs=test_llm_kwargs[ logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
...@@ -60,8 +60,10 @@ PRECISION = "float32" ...@@ -60,8 +60,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
}, },
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size: int): prefill_chunk_size: int):
"""Verify greedy equality with different batch size.""" """Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
max_output_len=output_len, batch_size,
seed=seed, max_output_len=output_len,
temperature=0.0, seed=seed,
logprobs=logprobs, temperature=0.0,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ...@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption( ...@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": k, "model": SPEC_MODEL,
"num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, ...@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": SPEC_MODEL,
}]) "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -62,7 +62,9 @@ PRECISION = "float32" ...@@ -62,7 +62,9 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"disable_logprobs_during_spec_decoding": False, "model": SPEC_MODEL,
"disable_logprobs": False,
},
}, },
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"disable_logprobs_during_spec_decoding": True, "model": SPEC_MODEL,
"disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [8]) @pytest.mark.parametrize("output_len", [8])
...@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change). # up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
max_output_len=output_len, batch_size,
seed=seed, max_output_len=output_len,
temperature=0.0, seed=seed,
logprobs=logprobs, temperature=0.0,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [2048]) @pytest.mark.parametrize("output_len", [2048])
...@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, ...@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# Speculative model # Speculative config
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
...@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, ...@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption( ...@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"model": SPEC_MODEL,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding( ...@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": SPEC_MODEL, "speculative_config": {
"num_speculative_tokens": k, "model": SPEC_MODEL,
"num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, ...@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": SPEC_MODEL, "model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4,
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
"speculative_model": SPEC_MODEL,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "model": SPEC_MODEL,
}]) "disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -57,7 +57,9 @@ PRECISION = "bfloat16" ...@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_config": {
"disable_logprobs_during_spec_decoding": False, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
}, },
{ {
"num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_config": {
"disable_logprobs_during_spec_decoding": True, "num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
output_len, batch_size,
seed, output_len,
logprobs=logprobs, seed,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, ...@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
}, },
]) ])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption( ...@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"num_speculative_tokens": k, "speculative_config": {
"num_speculative_tokens": k,
},
} }
# Try a range of num. speculative tokens # Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS) for k in range(1, 1 + MAX_SPEC_TOKENS)
...@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, ...@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS, "num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4
}]) },
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator, ...@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
# Chunked prefill enabled with small value # Chunked prefill enabled with small value
# to make sure we get mixed batches. # to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, ...@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
}, },
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "JackFram/llama-68m", "model": "JackFram/llama-68m",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"enable_chunked_prefill": False, "disable_logprobs": False,
"disable_logprobs_during_spec_decoding": False },
}, { "enable_chunked_prefill": False,
"speculative_model": "JackFram/llama-68m", }, {
"num_speculative_tokens": 3, "speculative_config": {
"enable_chunked_prefill": True, "model": "JackFram/llama-68m",
"max_num_batched_tokens": 4, "num_speculative_tokens": 3,
"max_num_seqs": 4, "disable_logprobs": False,
"disable_logprobs_during_spec_decoding": False },
}]) "enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
}])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
[ [
...@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted. whether all speculative tokens are accepted.
""" """
ensure_all_accepted = per_test_common_llm_kwargs.get( ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model") "model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
...@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( ...@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( ...@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( ...@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( ...@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( ...@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( ...@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
...@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, ...@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM # Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens. # to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32, "speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
"speculative_max_model_len": 32,
}, },
]) ])
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
...@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"speculative_disable_by_batch_size": 2, "num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": 5, "model": "JackFram/llama-68m",
"speculative_disable_by_batch_size": 2, "num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
...@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, ...@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": k, "model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{ ] + [{
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": k, "model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4, "max_num_seqs": 4,
...@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": k, "model": "JackFram/llama-68m",
"spec_decoding_acceptance_method": "typical_acceptance_sampler", "num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False "enable_chunked_prefill": False
} }
# Try a range of common k. # Try a range of common k.
for k in [1, 2, 3] for k in [1, 2, 3]
] + [{ ] + [{
"speculative_model": "JackFram/llama-68m", "speculative_config": {
"num_speculative_tokens": k, "model": "JackFram/llama-68m",
"spec_decoding_acceptance_method": "typical_acceptance_sampler", "num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
......
...@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test ...@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"speculative_disable_mqa_scorer": False, "prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"speculative_disable_mqa_scorer": True, "prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"disable_logprobs_during_spec_decoding": False, "prompt_lookup_max": 3,
"disable_logprobs": False,
},
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"disable_logprobs_during_spec_decoding": True, "prompt_lookup_max": 3,
"disable_logprobs": True,
},
}, },
]) ])
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
...@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int, batch_size: int, output_len: int, seed: int,
logprobs: int): logprobs: int):
"""Verify greedy equality on a tiny model with different batch size.""" """Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(
common_llm_kwargs, vllm_runner,
per_test_common_llm_kwargs, common_llm_kwargs,
baseline_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, baseline_llm_kwargs,
batch_size, test_llm_kwargs,
max_output_len=output_len, batch_size,
seed=seed, max_output_len=output_len,
temperature=0.0, seed=seed,
logprobs=logprobs, temperature=0.0,
prompt_logprobs=logprobs, logprobs=logprobs,
disable_logprobs=test_llm_kwargs[ prompt_logprobs=logprobs,
'disable_logprobs_during_spec_decoding']) disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [ @pytest.mark.parametrize("test_llm_kwargs", [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False, "enable_chunked_prefill": False,
}, },
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": 5, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4, "max_num_batched_tokens": 4,
"max_num_seqs": 4 "max_num_seqs": 4
}, },
...@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ...@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs", "test_llm_kwargs",
[ [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": k, "method": "ngram",
"ngram_prompt_lookup_max": 3, "num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 3, 5] for k in [1, 3, 5]
] + [ ] + [
{ {
"speculative_model": "[ngram]", "speculative_config": {
"num_speculative_tokens": k, "method": "ngram",
"ngram_prompt_lookup_max": 1, "num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
} }
# Try a range of common k, as well as large speculation. # Try a range of common k, as well as large speculation.
for k in [1, 3, 5] for k in [1, 3, 5]
...@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed: int): seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram prompt_lookup_max.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
...@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_model": "[ngram]", "method": "ngram",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3, "prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4 "disable_by_batch_size": 4
}, { },
"speculative_model": "[ngram]", }, {
"num_speculative_tokens": 5, "speculative_config": {
"ngram_prompt_lookup_max": 3, "method": "ngram",
"speculative_disable_by_batch_size": 4, "num_speculative_tokens": 5,
"enable_chunked_prefill": True, "prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True, "disable_by_batch_size": 4,
"max_num_batched_tokens": 4, "disable_mqa_scorer": True,
"max_num_seqs": 4 },
}]) "enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
...@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed: int): seed: int):
"""Verify that ngram speculative decoding produces exact equality """Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and to without spec decode with many different values of k and
different ngram_prompt_lookup_max. different ngram prompt_lookup_max.
""" """
run_equality_correctness_test(vllm_runner, run_equality_correctness_test(vllm_runner,
common_llm_kwargs, common_llm_kwargs,
...@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", @pytest.mark.parametrize("test_llm_kwargs", [{
[{ "speculative_config": {
"speculative_disable_mqa_scorer": True, "method": "ngram",
}]) "num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_len", "output_len",
......
...@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m" ...@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# speculative model # speculative config
"speculative_model": "JackFram/llama-160m", "speculative_config": {
"model": "JackFram/llama-160m",
# num speculative tokens "num_speculative_tokens": 3,
"num_speculative_tokens": 3, },
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
......
...@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type): ...@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
max_input_length=None, max_input_length=None,
) )
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=None) prompt="prompt", lora_request=None)
assert reference_tokenizer.encode( assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async( "prompt") == await tokenizer_group.encode_async(prompt="prompt",
request_id="request_id", prompt="prompt", lora_request=None) lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None), assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer( assert tokenizer_group.get_lora_tokenizer(
...@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type): ...@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
# and check that all requests are processed correctly. # and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5 num_requests = tokenizer_group_pool.pool_size * 5
requests = [ requests = [
tokenizer_group_pool.encode_async(request_id=str(i), tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
prompt=f"prompt {i}",
lora_request=None) lora_request=None)
for i in range(num_requests) for i in range(num_requests)
] ]
...@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
fail_at[0] = 1000 fail_at[0] = 1000
# We should recover successfully. # We should recover successfully.
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
prompt="prompt", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
# Check that we have a new actor # Check that we have a new actor
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors) assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
...@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# We should fail after re-initialization. # We should fail after re-initialization.
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt",
prompt="prompt",
lora_request=None) lora_request=None)
# check_health should raise the same thing # check_health should raise the same thing
...@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# Prompt too long error # Prompt too long error
with pytest.raises(ValueError): with pytest.raises(ValueError):
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
prompt="prompt" * 100,
lora_request=None) lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
prompt="prompt",
lora_request=None)
# Actors should stay the same. # Actors should stay the same.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
...@@ -786,7 +786,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator: ...@@ -786,7 +786,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
without enough resources, or called when filtering tests to run directly. without enough resources, or called when filtering tests to run directly.
""" """
try: try:
if current_platform.is_cpu() or current_platform.is_openvino(): if current_platform.is_cpu():
memory_gb = 0 memory_gb = 0
else: else:
memory_gb = current_platform.get_device_total_memory() / GB_bytes memory_gb = current_platform.get_device_total_memory() / GB_bytes
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest import pytest
import torch
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, ...@@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, PrefixCachingMetrics, KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
hash_block_tokens, hash_block_tokens,
hash_request_tokens) hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -314,3 +318,107 @@ def test_metrics(): ...@@ -314,3 +318,107 @@ def test_metrics():
assert metrics.aggregated_query_total == 0 assert metrics.aggregated_query_total == 0
assert metrics.aggregated_query_hit == 0 assert metrics.aggregated_query_hit == 0
assert not metrics.query_queue assert not metrics.query_queue
def test_unify_kv_cache_configs():
def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
]
unify_kv_cache_configs(same_kv_cache_config)
assert same_kv_cache_config[0].num_blocks == 10
assert same_kv_cache_config[1].num_blocks == 10
need_sort_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
],
),
]
unify_kv_cache_configs(need_sort_kv_cache_config)
assert need_sort_kv_cache_config[0].num_blocks == 10
assert need_sort_kv_cache_config[1].num_blocks == 10
diff_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=8)),
],
),
]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)
...@@ -6,7 +6,8 @@ import pytest ...@@ -6,7 +6,8 @@ import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
......
...@@ -70,12 +70,16 @@ def test_ngram_correctness( ...@@ -70,12 +70,16 @@ def test_ngram_correctness(
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
spec_llm = LLM(model=model_name, spec_llm = LLM(
speculative_model='[ngram]', model=model_name,
ngram_prompt_lookup_max=5, speculative_config={
ngram_prompt_lookup_min=3, "method": "ngram",
num_speculative_tokens=3, "prompt_lookup_max": 5,
max_model_len=1024) "prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
misses = 0 misses = 0
......
...@@ -158,6 +158,22 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -158,6 +158,22 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
# Sending duplicate requests with same request_id
req0 = make_request()
req1 = make_request()
req0.request_id = req1.request_id = "test"
engine_core.add_request(req0)
while len(engine_core.step().outputs) > 0:
pass
engine_core.add_request(req1)
while len(engine_core.step().outputs) > 0:
pass
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
......
...@@ -57,6 +57,50 @@ def test_guided_json_completion( ...@@ -57,6 +57,50 @@ def test_guided_json_completion(
jsonschema.validate(instance=output_json, schema=sample_json_schema) jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_disable_any_whitespace(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
if guided_decoding_backend != "xgrammar":
pytest.skip("disable-any-whitespace is only supported for xgrammar.")
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
...@@ -301,7 +345,6 @@ def test_guided_choice_completion( ...@@ -301,7 +345,6 @@ def test_guided_choice_completion(
prompts="The best language for type-safe systems programming is ", prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
assert output is not None assert output is not None
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import numpy as np import numpy as np
from vllm.v1.spec_decode.ngram_proposer import (_find_subarray_kmp, from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array) _kmp_lps_array)
...@@ -35,3 +36,53 @@ def test_find_subarray_kmp(): ...@@ -35,3 +36,53 @@ def test_find_subarray_kmp():
# Return on the first match # Return on the first match
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3), np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
np.array([6, 2, 3])) np.array([6, 2, 3]))
def test_ngram_proposer():
proposer = NgramProposer()
# No match.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 5]),
min_n=2,
max_n=2,
k=2,
)
assert result is None
# No match for 4-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=4,
max_n=4,
k=2,
)
assert result is None
# No match for 4-gram but match for 3-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = proposer.propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
result = proposer.propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
min_n=2,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
# SPDX-License-Identifier: Apache-2.0
"""
Test:
* Tests for MultiHeadAttention layer
"""
import pytest
import torch
import torch_xla
import torch_xla.core
import torch_xla.core.xla_model
from vllm import envs
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
def ref_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.matmul(attn_weights, value).transpose(1, 2)
return out
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
batch_size: int,
seq_len: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
device: str,
):
current_platform.seed_everything(0)
# These are expected to be f32
q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
k = torch.randn(batch_size,
seq_len,
num_kv_heads * head_size,
device=device)
v = torch.randn(batch_size,
seq_len,
num_kv_heads * head_size,
device=device)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
q = q.reshape(batch_size, seq_len, num_heads, head_size)
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = ref_attention(
q,
k,
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
# torch_xla flash_attn kernel is less accurate but much faster
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time
import pytest
from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with tempfile.TemporaryDirectory() as temp_dir:
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
# Compiling model init may still take some time, enforce_eager to skip.
llm = LLM(model_name,
enforce_eager=True,
max_num_seqs=16,
max_model_len=1024,
gpu_memory_utilization=0.5)
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
]
# First inference should be slow
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # TODO too slow!
top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
_ = llm.generate(prompts, sampling_params)
run1 = time() - s
# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params = SamplingParams(temperature=0.1,
top_k=12,
min_p=0.8,
max_tokens=24)
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
# Much faster after compiling
assert run1 * 0.1 > run2
print("TIMES", run1, run2)
# Third request with min_p set to "None". It will not trigger
# recompilation as a default 0 value will be used.
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
s = time()
_ = llm.generate(prompts, sampling_params)
run3 = time() - s
assert run1 * 0.1 > run3
print("TIMES", run1, run3)
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_different(model_name: str):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=True,
max_num_seqs=1,
max_model_len=64,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization=0.5)
prompts = [
"Write a short story about a robot that dreams for the first time."
]
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
output = llm.generate(prompts, sampling_params)
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
output2 = llm.generate(prompts, sampling_params)
assert output[0].outputs[0].text != output2[0].outputs[0].text
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment