Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53076d70
Commit
53076d70
authored
Mar 24, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-ori
parents
322a0be6
9c5c81b0
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
983 additions
and
426 deletions
+983
-426
tests/spec_decode/e2e/test_integration.py
tests/spec_decode/e2e/test_integration.py
+28
-18
tests/spec_decode/e2e/test_integration_dist_tp2.py
tests/spec_decode/e2e/test_integration_dist_tp2.py
+40
-37
tests/spec_decode/e2e/test_integration_dist_tp4.py
tests/spec_decode/e2e/test_integration_dist_tp4.py
+13
-15
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+102
-88
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+55
-40
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+54
-35
tests/spec_decode/e2e/test_mtp_correctness.py
tests/spec_decode/e2e/test_mtp_correctness.py
+39
-25
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+112
-65
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+88
-69
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+5
-5
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+9
-18
tests/utils.py
tests/utils.py
+1
-1
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+109
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+2
-1
tests/v1/e2e/test_ngram_spec_decode.py
tests/v1/e2e/test_ngram_spec_decode.py
+10
-6
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+16
-0
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+44
-1
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+52
-1
tests/v1/tpu/test_mha_attn.py
tests/v1/tpu/test_mha_attn.py
+109
-0
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+95
-0
No files found.
tests/spec_decode/e2e/test_integration.py
View file @
53076d70
...
@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
...
@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
[
[
{
{
# Identical models.
# Identical models.
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
...
@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[])
{
"speculative_model"
:
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
# Explicitly specify draft model quantization
# Explicitly specify draft model quantization
{
{
"speculative_model_quantization"
:
"gptq"
,
"speculative_config"
:
{
"model"
:
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"num_speculative_tokens"
:
5
,
"quantization"
:
"gptq"
,
},
},
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
{
"speculative_model_quantization"
:
"marlin"
,
"speculative_config"
:
{
"model"
:
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"num_speculative_tokens"
:
5
,
"quantization"
:
"marlin"
,
},
},
},
# Not explicitly specify draft model quantization
# Not explicitly specify draft model quantization
{
{
"speculative_model_quantization"
:
None
,
"speculative_config"
:
{
"model"
:
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
,
"num_speculative_tokens"
:
5
,
"quantization"
:
None
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
...
@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_disable_mqa_scorer"
:
True
,
"model"
:
"JackFram/llama-68m"
,
}])
"num_speculative_tokens"
:
3
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
...
@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
output_len
:
int
,
seed
:
int
):
"""Verify that
ngram
speculative decoding generates the same output
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
with batch expansion scorer and mqa scorer.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
...
...
tests/spec_decode/e2e/test_integration_dist_tp2.py
View file @
53076d70
...
@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
...
@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
[
[
"--speculative-model"
,
"--speculative_config"
,
"JackFram/llama-68m"
,
str
({
"--num-speculative-tokens"
,
"model"
:
"JackFram/llama-68m"
,
"3"
,
"num_speculative_tokens"
:
3
,
}),
],
],
[
[
"--speculative
-model
"
,
"--speculative
_config
"
,
"[ngram]"
,
str
({
"--num-speculative-tokens
"
,
"model"
:
"ngram
"
,
"5"
,
"num_speculative_tokens"
:
5
,
"--ngram-
prompt
-
lookup
-
max"
,
"
prompt
_
lookup
_
max"
:
3
,
"3"
,
})
,
],
],
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
...
@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
...
@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]])
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"model, test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
[(
"JackFram/llama-68m"
,
[
"model, test_llm_kwargs"
,
"--speculative-model"
,
[(
"JackFram/llama-68m"
,
[
"JackFram/llama-68m"
,
"--speculative_config"
,
"--num_speculative-tokens"
,
str
({
"5"
,
"model"
:
"JackFram/llama-68m"
,
"--speculative-draft-tensor-parallel-size"
,
"num_speculative_tokens"
:
5
,
"1"
,
"draft_tensor_parallel_size"
:
1
,
]),
}),
(
"ibm-granite/granite-3b-code-instruct"
,
[
]),
"--speculative-model"
,
(
"ibm-granite/granite-3b-code-instruct"
,
[
"ibm-granite/granite-3b-code-instruct"
,
"--speculative_config"
,
"--num_speculative-tokens"
,
str
({
"5"
,
"model"
:
"ibm-granite/granite-3b-code-instruct"
,
"--speculative-draft-tensor-parallel-size"
,
"num_speculative_tokens"
:
5
,
"1"
,
"draft_tensor_parallel_size"
:
1
,
])])
}),
])])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_draft_model_tp_lt_target_model_tp2
(
model
,
common_llm_kwargs
,
def
test_draft_model_tp_lt_target_model_tp2
(
model
,
common_llm_kwargs
,
...
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
...
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"model, test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"model, test_llm_kwargs"
,
[(
"JackFram/llama-68m"
,
[
[(
"JackFram/llama-68m"
,
[
"--speculative-model"
,
"--speculative_config"
,
"JackFram/llama-68m"
,
str
({
"--num_speculative-tokens"
,
"model"
:
"JackFram/llama-68m"
,
"3"
,
"num_speculative_tokens"
:
3
,
}),
]),
]),
(
"JackFram/llama-68m"
,
[
(
"JackFram/llama-68m"
,
[
"--speculative
-model
"
,
"--speculative
_config
"
,
"JackFram/llama-68m"
,
str
({
"--num_speculative-tokens
"
,
"model"
:
"JackFram/llama-68m
"
,
"3"
,
"num_speculative_tokens"
:
3
,
"--speculative-
draft
-
tensor
-
parallel
-
size"
,
"
draft
_
tensor
_
parallel
_
size"
:
1
,
"1"
,
})
,
])])
])])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
None
,
2
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
None
,
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
...
...
tests/spec_decode/e2e/test_integration_dist_tp4.py
View file @
53076d70
...
@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
...
@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4"
,
"4"
,
]])
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
[
[],
"--speculative-model"
,
f
"
{
SPEC_MODEL
}
"
,
"--num-speculative-tokens"
,
"5"
,
],
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
...
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[
[
#TODO(wooyeon): add spec_draft_dp=2 case
#TODO(wooyeon): add spec_draft_dp=2 case
[
[
"--speculative-draft-tensor-parallel-size"
,
"--speculative_config"
,
"1"
,
str
({
"model"
:
f
"
{
SPEC_MODEL
}
"
,
"num_speculative_tokens"
:
5
,
"draft_tensor_parallel_size"
:
1
,
}),
],
],
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
...
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
...
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
[
[
"--speculative-model"
,
f
"
{
SPEC_MODEL
}
"
,
"--num-speculative-tokens"
,
"5"
,
# Artificially limit the draft model max model len; this forces vLLM
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len"
,
"--speculative_config"
,
"32"
,
str
({
"model"
:
f
"
{
SPEC_MODEL
}
"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
32
,
}),
],
],
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
...
tests/spec_decode/e2e/test_logprobs.py
View file @
53076d70
...
@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
...
@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"JackFram/llama-68m"
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
"disable_logprobs"
:
False
,
},
{
},
"speculative_model"
:
"JackFram/llama-68m"
,
},
{
"num_speculative_tokens"
:
3
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
True
,
"model"
:
"JackFram/llama-68m"
,
}])
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
...
@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill.
as well as with and without chunked prefill.
"""
"""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
common_llm_kwargs
)
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
common_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
output_len
,
batch_size
,
seed
,
output_len
,
temperature
=
0.0
,
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
prompt_logprobs
=
logprobs
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
prompt_logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
...
@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"JackFram/llama-160m"
,
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
"disable_logprobs"
:
False
,
},
{
},
"speculative_model"
:
"JackFram/llama-160m"
,
},
{
"num_speculative_tokens"
:
6
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
False
,
"model"
:
"JackFram/llama-160m"
,
}])
"num_speculative_tokens"
:
6
,
"disable_logprobs"
:
False
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
...
@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""Veriy logprob greedy equality with different speculation lens.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
output_len
,
batch_size
,
seed
,
output_len
,
temperature
=
0.0
,
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
disable_logprobs
=
test_llm_kwargs
[
logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
...
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[{
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
3
,
"model"
:
"JackFram/llama-160m"
,
"disable_logprobs_during_spec_decoding"
:
False
,
"num_speculative_tokens"
:
3
,
"disable_logprobs"
:
False
,
# Artificially limit the draft model max model len; this forces vLLM
# Artificially limit the draft model max model len; this forces
# to skip speculation once the sequences grow beyond 32-k tokens.
# vLLM to skip speculation once the sequences grow beyond 32-k
"speculative_max_model_len"
:
32
,
# tokens.
"max_model_len"
:
32
,
},
}])
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
...
@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed
:
int
,
logprobs
:
int
):
seed
:
int
,
logprobs
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
output_len
,
batch_size
,
seed
,
output_len
,
temperature
=
0.0
,
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
disable_logprobs
=
test_llm_kwargs
[
logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
...
@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"JackFram/llama-160m"
,
"model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
"disable_logprobs"
:
False
,
}])
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
...
@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"JackFram/llama-68m"
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
True
,
"disable_logprobs"
:
True
,
}])
},
}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
...
@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
Token choices should match with the base model.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
output_len
,
batch_size
,
seed
,
output_len
,
temperature
=
0.0
,
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
disable_logprobs
=
test_llm_kwargs
[
logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
53076d70
...
@@ -60,8 +60,10 @@ PRECISION = "float32"
...
@@ -60,8 +60,10 @@ PRECISION = "float32"
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"model"
:
SPEC_MODEL
,
"disable_logprobs_during_spec_decoding"
:
False
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
False
,
},
},
},
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"model"
:
SPEC_MODEL
,
"disable_logprobs_during_spec_decoding"
:
True
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
True
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size
:
int
):
prefill_chunk_size
:
int
):
"""Verify greedy equality with different batch size."""
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
test_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
max_output_len
=
output_len
,
batch_size
,
seed
=
seed
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seed
=
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
prompt_logprobs
=
logprobs
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
prompt_logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
...
@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
...
@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
},
}
}
# Try a range of num. speculative tokens
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
...
@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
...
@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
SPEC_MODEL
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
"disable_by_batch_size"
:
4
,
}])
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model
# Main model
"model_name"
:
MAIN_MODEL
,
"model_name"
:
MAIN_MODEL
,
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_disable_mqa_scorer"
:
True
,
"model"
:
SPEC_MODEL
,
}])
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_by_batch_size"
:
4
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
...
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
53076d70
...
@@ -62,7 +62,9 @@ PRECISION = "float32"
...
@@ -62,7 +62,9 @@ PRECISION = "float32"
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
False
,
"model"
:
SPEC_MODEL
,
"disable_logprobs"
:
False
,
},
},
},
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
True
,
"model"
:
SPEC_MODEL
,
"disable_logprobs"
:
True
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
])
...
@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change).
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
baseline_llm_kwargs
)
maybe_enable_chunked_prefill
(
prefill_chunk_size
,
baseline_llm_kwargs
)
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
max_output_len
=
output_len
,
batch_size
,
seed
=
seed
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seed
=
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
prompt_logprobs
=
logprobs
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
prompt_logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
2048
])
...
@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
...
@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model
# Main model
"model_name"
:
MAIN_MODEL
,
"model_name"
:
MAIN_MODEL
,
# Speculative model
# Speculative config
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
...
@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
...
@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
...
@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"model"
:
SPEC_MODEL
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
...
@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
SPEC_MODEL
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
},
}
}
# Try a range of num. speculative tokens
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
...
@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
...
@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
SPEC_MODEL
,
"model"
:
SPEC_MODEL
,
"speculative_disable_by_batch_size"
:
4
"disable_by_batch_size"
:
4
,
}])
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
"speculative_model"
:
SPEC_MODEL
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_disable_mqa_scorer"
:
True
,
"model"
:
SPEC_MODEL
,
}])
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
...
tests/spec_decode/e2e/test_mtp_correctness.py
View file @
53076d70
...
@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
...
@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
False
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
False
,
},
},
},
{
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"disable_logprobs_during_spec_decoding"
:
True
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs"
:
True
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
logprobs
:
int
):
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
output_len
,
batch_size
,
seed
,
output_len
,
logprobs
=
logprobs
,
seed
,
prompt_logprobs
=
logprobs
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
prompt_logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
...
@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
...
@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"num_speculative_tokens"
:
k
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
},
}
}
# Try a range of num. speculative tokens
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
...
@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
...
@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
"disable_by_batch_size"
:
4
}])
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
...
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
53076d70
...
@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
...
@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs"
,
"per_test_common_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
# Chunked prefill enabled with small value
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
# to make sure we get mixed batches.
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
...
@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"JackFram/llama-68m"
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
False
,
"disable_logprobs"
:
False
,
"disable_logprobs_during_spec_decoding"
:
False
},
},
{
"enable_chunked_prefill"
:
False
,
"speculative_model"
:
"JackFram/llama-68m"
,
},
{
"num_speculative_tokens"
:
3
,
"speculative_config"
:
{
"enable_chunked_prefill"
:
True
,
"model"
:
"JackFram/llama-68m"
,
"max_num_batched_tokens"
:
4
,
"num_speculative_tokens"
:
3
,
"max_num_seqs"
:
4
,
"disable_logprobs"
:
False
,
"disable_logprobs_during_spec_decoding"
:
False
},
}])
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
}])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
[
[
...
@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
...
@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted.
whether all speculative tokens are accepted.
"""
"""
ensure_all_accepted
=
per_test_common_llm_kwargs
.
get
(
ensure_all_accepted
=
per_test_common_llm_kwargs
.
get
(
"model_name"
)
==
test_llm_kwargs
.
get
(
"speculative_model"
)
"model_name"
)
==
test_llm_kwargs
.
get
(
"speculative_
config"
)[
"
model"
]
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
...
@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
...
@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
...
@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
...
@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
...
@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
...
@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
...
@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
...
@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Artificially limit the draft model max model len; this forces vLLM
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
"speculative_config"
:
{
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
32
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"max_model_len"
:
32
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
"max_num_seqs"
:
4
,
"speculative_max_model_len"
:
32
,
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
...
@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
...
@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"speculative_disable_by_batch_size"
:
2
,
"num_speculative_tokens"
:
5
,
"disable_by_batch_size"
:
2
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"model"
:
"JackFram/llama-68m"
,
"speculative_disable_by_batch_size"
:
2
,
"num_speculative_tokens"
:
5
,
"disable_by_batch_size"
:
2
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
"max_num_seqs"
:
4
,
...
@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
...
@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
}
}
# Try a range of common k, as well as large speculation.
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
for
k
in
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
63
]
]
+
[{
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
k
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
,
"max_num_seqs"
:
4
,
...
@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
...
@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
"JackFram/llama-68m"
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"num_speculative_tokens"
:
k
,
"acceptance_method"
:
"typical_acceptance_sampler"
,
},
"enable_chunked_prefill"
:
False
"enable_chunked_prefill"
:
False
}
}
# Try a range of common k.
# Try a range of common k.
for
k
in
[
1
,
2
,
3
]
for
k
in
[
1
,
2
,
3
]
]
+
[{
]
+
[{
"speculative_model"
:
"JackFram/llama-68m"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"model"
:
"JackFram/llama-68m"
,
"spec_decoding_acceptance_method"
:
"typical_acceptance_sampler"
,
"num_speculative_tokens"
:
k
,
"acceptance_method"
:
"typical_acceptance_sampler"
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
53076d70
...
@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
...
@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"speculative_disable_mqa_scorer"
:
False
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
False
,
},
},
},
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"speculative_disable_mqa_scorer"
:
True
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"disable_logprobs_during_spec_decoding"
:
False
,
"prompt_lookup_max"
:
3
,
"disable_logprobs"
:
False
,
},
},
},
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"disable_logprobs_during_spec_decoding"
:
True
,
"prompt_lookup_max"
:
3
,
"disable_logprobs"
:
True
,
},
},
},
])
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
...
@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
logprobs
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
common_llm_kwargs
,
vllm_runner
,
per_test_common_llm_kwargs
,
common_llm_kwargs
,
baseline_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
baseline_llm_kwargs
,
batch_size
,
test_llm_kwargs
,
max_output_len
=
output_len
,
batch_size
,
seed
=
seed
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seed
=
seed
,
logprobs
=
logprobs
,
temperature
=
0.0
,
prompt_logprobs
=
logprobs
,
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
prompt_logprobs
=
logprobs
,
'disable_logprobs_during_spec_decoding'
])
disable_logprobs
=
test_llm_kwargs
[
"speculative_config"
]
[
"disable_logprobs"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
...
@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
},
"enable_chunked_prefill"
:
False
,
"enable_chunked_prefill"
:
False
,
},
},
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
5
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
"enable_chunked_prefill"
:
True
,
"enable_chunked_prefill"
:
True
,
"speculative_disable_mqa_scorer"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
"max_num_seqs"
:
4
},
},
...
@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
...
@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs"
,
"test_llm_kwargs"
,
[
[
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
3
,
"num_speculative_tokens"
:
k
,
"prompt_lookup_max"
:
3
,
},
}
}
# Try a range of common k, as well as large speculation.
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
for
k
in
[
1
,
3
,
5
]
]
+
[
]
+
[
{
{
"speculative_model"
:
"[ngram]"
,
"speculative_config"
:
{
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
"ngram_prompt_lookup_max"
:
1
,
"num_speculative_tokens"
:
k
,
"prompt_lookup_max"
:
1
,
},
}
}
# Try a range of common k, as well as large speculation.
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
for
k
in
[
1
,
3
,
5
]
...
@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
...
@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed
:
int
):
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
to without spec decode with many different values of k and
different ngram
_
prompt_lookup_max.
different ngram
prompt_lookup_max.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
...
@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
...
@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_model"
:
"[ngram]"
,
"method"
:
"ngram"
,
"num_speculative_tokens"
:
5
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
"disable_by_batch_size"
:
4
},
{
},
"speculative_model"
:
"[ngram]"
,
},
{
"num_speculative_tokens"
:
5
,
"speculative_config"
:
{
"ngram_prompt_lookup_max"
:
3
,
"method"
:
"ngram"
,
"speculative_disable_by_batch_size"
:
4
,
"num_speculative_tokens"
:
5
,
"enable_chunked_prefill"
:
True
,
"prompt_lookup_max"
:
3
,
"speculative_disable_mqa_scorer"
:
True
,
"disable_by_batch_size"
:
4
,
"max_num_batched_tokens"
:
4
,
"disable_mqa_scorer"
:
True
,
"max_num_seqs"
:
4
},
}])
"enable_chunked_prefill"
:
True
,
"max_num_batched_tokens"
:
4
,
"max_num_seqs"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed
:
int
):
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
to without spec decode with many different values of k and
different ngram
_
prompt_lookup_max.
different ngram
prompt_lookup_max.
"""
"""
run_equality_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
...
@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
# Required for spec decode.
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
[{
"speculative_config"
:
{
"speculative_disable_mqa_scorer"
:
True
,
"method"
:
"ngram"
,
}])
"num_speculative_tokens"
:
5
,
"prompt_lookup_max"
:
3
,
"disable_mqa_scorer"
:
True
,
},
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"output_len"
,
"output_len"
,
...
...
tests/spec_decode/e2e/test_seed.py
View file @
53076d70
...
@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
...
@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test.
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
# speculative
model
# speculative
config
"speculative_
model"
:
"JackFram/llama-160m"
,
"speculative_
config"
:
{
"model"
:
"JackFram/llama-160m"
,
#
num
speculative
tokens
"
num
_
speculative
_
tokens
"
:
3
,
"num_speculative_tokens"
:
3
,
}
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
...
...
tests/tokenization/test_tokenizer_group.py
View file @
53076d70
...
@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
...
@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
max_input_length
=
None
,
max_input_length
=
None
,
)
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer_group
.
encode
(
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer_group
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
prompt
=
"prompt"
,
lora_request
=
None
)
assert
reference_tokenizer
.
encode
(
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
prompt
=
"prompt"
,
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
lora_request
=
None
)
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
assert
tokenizer_group
.
get_lora_tokenizer
(
...
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
...
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
# and check that all requests are processed correctly.
# and check that all requests are processed correctly.
num_requests
=
tokenizer_group_pool
.
pool_size
*
5
num_requests
=
tokenizer_group_pool
.
pool_size
*
5
requests
=
[
requests
=
[
tokenizer_group_pool
.
encode_async
(
request_id
=
str
(
i
),
tokenizer_group_pool
.
encode_async
(
prompt
=
f
"prompt
{
i
}
"
,
prompt
=
f
"prompt
{
i
}
"
,
lora_request
=
None
)
lora_request
=
None
)
for
i
in
range
(
num_requests
)
for
i
in
range
(
num_requests
)
]
]
...
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
...
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
fail_at
[
0
]
=
1000
fail_at
[
0
]
=
1000
# We should recover successfully.
# We should recover successfully.
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
prompt
=
"prompt"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# Check that we have a new actor
# Check that we have a new actor
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
...
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
...
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# We should fail after re-initialization.
# We should fail after re-initialization.
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
RuntimeError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
prompt
=
"prompt"
,
lora_request
=
None
)
lora_request
=
None
)
# check_health should raise the same thing
# check_health should raise the same thing
...
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
...
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# Prompt too long error
# Prompt too long error
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
*
100
,
prompt
=
"prompt"
*
100
,
lora_request
=
None
)
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
prompt
=
"prompt"
,
lora_request
=
None
)
# Actors should stay the same.
# Actors should stay the same.
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
tests/utils.py
View file @
53076d70
...
@@ -786,7 +786,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
...
@@ -786,7 +786,7 @@ def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
without enough resources, or called when filtering tests to run directly.
without enough resources, or called when filtering tests to run directly.
"""
"""
try
:
try
:
if
current_platform
.
is_cpu
()
or
current_platform
.
is_openvino
()
:
if
current_platform
.
is_cpu
():
memory_gb
=
0
memory_gb
=
0
else
:
else
:
memory_gb
=
current_platform
.
get_device_total_memory
()
/
GB_bytes
memory_gb
=
current_platform
.
get_device_total_memory
()
/
GB_bytes
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
pytest
import
torch
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
...
@@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock
,
PrefixCachingMetrics
,
KVCacheBlock
,
PrefixCachingMetrics
,
generate_block_hash_extra_keys
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_block_tokens
,
hash_request_tokens
)
hash_request_tokens
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -314,3 +318,107 @@ def test_metrics():
...
@@ -314,3 +318,107 @@ def test_metrics():
assert
metrics
.
aggregated_query_total
==
0
assert
metrics
.
aggregated_query_total
==
0
assert
metrics
.
aggregated_query_hit
==
0
assert
metrics
.
aggregated_query_hit
==
0
assert
not
metrics
.
query_queue
assert
not
metrics
.
query_queue
def
test_unify_kv_cache_configs
():
def
new_kv_cache_spec
(
block_size
=
16
,
num_kv_heads
=
2
,
head_size
=
64
,
dtype
=
torch
.
float32
,
use_mla
=
False
):
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
use_mla
=
use_mla
)
same_kv_cache_config
=
[
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
],
),
KVCacheConfig
(
num_blocks
=
20
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
],
),
]
unify_kv_cache_configs
(
same_kv_cache_config
)
assert
same_kv_cache_config
[
0
].
num_blocks
==
10
assert
same_kv_cache_config
[
1
].
num_blocks
==
10
need_sort_kv_cache_config
=
[
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
],
),
KVCacheConfig
(
num_blocks
=
20
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
],
),
]
unify_kv_cache_configs
(
need_sort_kv_cache_config
)
assert
need_sort_kv_cache_config
[
0
].
num_blocks
==
10
assert
need_sort_kv_cache_config
[
1
].
num_blocks
==
10
diff_kv_cache_config
=
[
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
],
),
KVCacheConfig
(
num_blocks
=
20
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
"layer2"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
8
)),
],
),
]
with
pytest
.
raises
(
AssertionError
):
unify_kv_cache_configs
(
diff_kv_cache_config
)
tests/v1/core/test_scheduler.py
View file @
53076d70
...
@@ -6,7 +6,8 @@ import pytest
...
@@ -6,7 +6,8 @@ import pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler
import
Scheduler
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
...
...
tests/v1/e2e/test_ngram_spec_decode.py
View file @
53076d70
...
@@ -70,12 +70,16 @@ def test_ngram_correctness(
...
@@ -70,12 +70,16 @@ def test_ngram_correctness(
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
del
ref_llm
spec_llm
=
LLM
(
model
=
model_name
,
spec_llm
=
LLM
(
speculative_model
=
'[ngram]'
,
model
=
model_name
,
ngram_prompt_lookup_max
=
5
,
speculative_config
=
{
ngram_prompt_lookup_min
=
3
,
"method"
:
"ngram"
,
num_speculative_tokens
=
3
,
"prompt_lookup_max"
:
5
,
max_model_len
=
1024
)
"prompt_lookup_min"
:
3
,
"num_speculative_tokens"
:
3
,
},
max_model_len
=
1024
,
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
matches
=
0
matches
=
0
misses
=
0
misses
=
0
...
...
tests/v1/engine/test_engine_core.py
View file @
53076d70
...
@@ -158,6 +158,22 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
...
@@ -158,6 +158,22 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
# Sending duplicate requests with same request_id
req0
=
make_request
()
req1
=
make_request
()
req0
.
request_id
=
req1
.
request_id
=
"test"
engine_core
.
add_request
(
req0
)
while
len
(
engine_core
.
step
().
outputs
)
>
0
:
pass
engine_core
.
add_request
(
req1
)
while
len
(
engine_core
.
step
().
outputs
)
>
0
:
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
def
test_engine_core_advanced_sampling
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_engine_core_advanced_sampling
(
monkeypatch
:
pytest
.
MonkeyPatch
):
...
...
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
53076d70
...
@@ -57,6 +57,50 @@ def test_guided_json_completion(
...
@@ -57,6 +57,50 @@ def test_guided_json_completion(
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS_TO_TEST
)
def
test_guided_json_completion_disable_any_whitespace
(
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_json_schema
:
dict
[
str
,
Any
],
guided_decoding_backend
:
str
,
model_name
:
str
,
):
if
guided_decoding_backend
!=
"xgrammar"
:
pytest
.
skip
(
"disable-any-whitespace is only supported for xgrammar."
)
guided_decoding_backend
=
'xgrammar:disable-any-whitespace'
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
assert
"
\n
"
not
in
generated_text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
...
@@ -301,7 +345,6 @@ def test_guided_choice_completion(
...
@@ -301,7 +345,6 @@ def test_guided_choice_completion(
prompts
=
"The best language for type-safe systems programming is "
,
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
use_tqdm
=
True
)
assert
outputs
is
not
None
assert
outputs
is
not
None
for
output
in
outputs
:
for
output
in
outputs
:
assert
output
is
not
None
assert
output
is
not
None
...
...
tests/v1/spec_decode/test_ngram.py
View file @
53076d70
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
import
numpy
as
np
import
numpy
as
np
from
vllm.v1.spec_decode.ngram_proposer
import
(
_find_subarray_kmp
,
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
_find_subarray_kmp
,
_kmp_lps_array
)
_kmp_lps_array
)
...
@@ -35,3 +36,53 @@ def test_find_subarray_kmp():
...
@@ -35,3 +36,53 @@ def test_find_subarray_kmp():
# Return on the first match
# Return on the first match
np
.
testing
.
assert_array_equal
(
_find_subarray_kmp
(
X
,
1
,
3
),
np
.
testing
.
assert_array_equal
(
_find_subarray_kmp
(
X
,
1
,
3
),
np
.
array
([
6
,
2
,
3
]))
np
.
array
([
6
,
2
,
3
]))
def
test_ngram_proposer
():
proposer
=
NgramProposer
()
# No match.
result
=
proposer
.
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
5
]),
min_n
=
2
,
max_n
=
2
,
k
=
2
,
)
assert
result
is
None
# No match for 4-gram.
result
=
proposer
.
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
min_n
=
4
,
max_n
=
4
,
k
=
2
,
)
assert
result
is
None
# No match for 4-gram but match for 3-gram.
result
=
proposer
.
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
4
,
1
]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result
=
proposer
.
propose
(
context_token_ids
=
np
.
array
([
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
result
=
proposer
.
propose
(
context_token_ids
=
np
.
array
([
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
min_n
=
2
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 2]
tests/v1/tpu/test_mha_attn.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
"""
Test:
* Tests for MultiHeadAttention layer
"""
import
pytest
import
torch
import
torch_xla
import
torch_xla.core
import
torch_xla.core.xla_model
from
vllm
import
envs
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
if
not
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test."
,
allow_module_level
=
True
,
)
@
pytest
.
fixture
(
autouse
=
True
)
def
clear_cache
():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend
.
cache_clear
()
def
ref_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
)
->
torch
.
Tensor
:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
attn_weights
=
scale
*
torch
.
matmul
(
query
,
key
.
transpose
(
2
,
3
))
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
matmul
(
attn_weights
,
value
).
transpose
(
1
,
2
)
return
out
BATCH_SIZES
=
[
1
,
16
]
SEQ_LENS
=
[
1
]
NUM_HEADS
=
[
1
,
16
]
NUM_KV_HEADS
=
[
1
]
HEAD_SIZES
=
[
64
,
80
]
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_kv_heads"
,
NUM_KV_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
torch_xla
.
core
.
xla_model
.
xla_device
()])
def
test_mha_attn_forward
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
device
:
str
,
):
current_platform
.
seed_everything
(
0
)
# These are expected to be f32
q
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
device
=
device
)
k
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
,
device
=
device
)
v
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
,
device
=
device
)
scale
=
1.0
/
head_size
**
0.5
attn
=
MultiHeadAttention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
output
=
attn
(
q
,
k
,
v
)
assert
num_heads
%
num_kv_heads
==
0
num_queries_per_kv
=
num_heads
//
num_kv_heads
q
=
q
.
reshape
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
k
=
k
.
reshape
(
batch_size
,
seq_len
,
num_kv_heads
,
head_size
)
v
=
v
.
reshape
(
batch_size
,
seq_len
,
num_kv_heads
,
head_size
)
if
num_queries_per_kv
>
1
:
k
=
torch
.
repeat_interleave
(
k
,
num_queries_per_kv
,
dim
=
2
)
v
=
torch
.
repeat_interleave
(
v
,
num_queries_per_kv
,
dim
=
2
)
ref_output
=
ref_attention
(
q
,
k
,
v
,
scale
=
scale
,
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
# torch_xla flash_attn kernel is less accurate but much faster
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-3
)
tests/v1/tpu/test_sampler.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
import
tempfile
from
time
import
time
import
pytest
from
vllm
import
LLM
,
envs
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
if
not
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test."
,
allow_module_level
=
True
,
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"D4nt3/Qwen2.5-two-layers"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
def
test_sampler_compilation
(
model_name
:
str
,
monkeypatch
):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
monkeypatch
.
setenv
(
"VLLM_XLA_CACHE_PATH"
,
temp_dir
)
# Compiling model init may still take some time, enforce_eager to skip.
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
16
,
max_model_len
=
1024
,
gpu_memory_utilization
=
0.5
)
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
]
# First inference should be slow
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
# top_p=0.6, # TODO too slow!
top_k
=
10
,
min_p
=
0.2
,
max_tokens
=
16
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run1
=
time
()
-
s
# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
top_k
=
12
,
min_p
=
0.8
,
max_tokens
=
24
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run2
=
time
()
-
s
# Much faster after compiling
assert
run1
*
0.1
>
run2
print
(
"TIMES"
,
run1
,
run2
)
# Third request with min_p set to "None". It will not trigger
# recompilation as a default 0 value will be used.
sampling_params
=
SamplingParams
(
max_tokens
=
24
,
temperature
=
0.0
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run3
=
time
()
-
s
assert
run1
*
0.1
>
run3
print
(
"TIMES"
,
run1
,
run3
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Qwen/Qwen2.5-1.5B-Instruct"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
def
test_sampler_different
(
model_name
:
str
):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
1
,
max_model_len
=
64
,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization
=
0.5
)
prompts
=
[
"Write a short story about a robot that dreams for the first time."
]
sampling_params
=
SamplingParams
(
temperature
=
0.9
,
min_p
=
0.2
,
max_tokens
=
64
)
output
=
llm
.
generate
(
prompts
,
sampling_params
)
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
min_p
=
0.8
,
max_tokens
=
64
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
output
[
0
].
outputs
[
0
].
text
!=
output2
[
0
].
outputs
[
0
].
text
Prev
1
2
3
4
5
6
7
8
9
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment