Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
775f00f8
Unverified
Commit
775f00f8
authored
Sep 11, 2024
by
Lily Liu
Committed by
GitHub
Sep 11, 2024
Browse files
[Speculative Decoding] Test refactor (#8317)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
8baa4549
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
929 additions
and
1044 deletions
+929
-1044
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-1
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+176
-299
tests/spec_decode/e2e/test_eagle_correctness.py
tests/spec_decode/e2e/test_eagle_correctness.py
+49
-48
tests/spec_decode/e2e/test_integration.py
tests/spec_decode/e2e/test_integration.py
+32
-20
tests/spec_decode/e2e/test_integration_dist_tp2.py
tests/spec_decode/e2e/test_integration_dist_tp2.py
+76
-79
tests/spec_decode/e2e/test_integration_dist_tp4.py
tests/spec_decode/e2e/test_integration_dist_tp4.py
+65
-61
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+93
-234
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+72
-46
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+102
-65
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+171
-136
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+58
-36
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+33
-19
No files found.
.buildkite/test-pipeline.yaml
View file @
775f00f8
...
...
@@ -217,7 +217,8 @@ steps:
commands
:
# See https://github.com/vllm-project/vllm/issues/5152
-
export VLLM_ATTENTION_BACKEND=XFORMERS
-
pytest -v -s spec_decode
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
-
label
:
LoRA Test %N
# 30min each
mirror_hardwares
:
[
amd
]
...
...
tests/spec_decode/e2e/conftest.py
View file @
775f00f8
import
asyncio
import
os
from
itertools
import
cycle
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
import
pytest
import
ray
import
torch
from
vllm
import
LLM
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm
import
LLM
,
SamplingParams
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
from
...conftest
import
cleanup
from
...utils
import
wait_for_gpu_memory_to_clear
from
...models.utils
import
check_logprobs_close
,
check_outputs_equal
from
...utils
import
RemoteOpenAIServer
class
AsyncLLM
:
"""AsyncLLM
Note: Current LLM class in vllm don't support async mode, for test purpose,
we implement async one in here. Maybe we could move to
vllm/entrypoints/llm.py in future.
Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes
to make to work in async mode.
"""
def
__init__
(
self
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer_mode
:
str
=
"auto"
,
skip_tokenizer_init
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
tensor_parallel_size
:
int
=
1
,
dtype
:
str
=
"auto"
,
quantization
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
)
->
None
:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
# Needed to engine_use_ray works as a deprecated feature,
# otherwise the following constructor will raise an exception
os
.
environ
[
"VLLM_ALLOW_ENGINE_USE_RAY"
]
=
"1"
engine_args
=
AsyncEngineArgs
(
model
=
model
,
tokenizer
=
tokenizer
,
tokenizer_mode
=
tokenizer_mode
,
skip_tokenizer_init
=
skip_tokenizer_init
,
trust_remote_code
=
trust_remote_code
,
tensor_parallel_size
=
tensor_parallel_size
,
dtype
=
dtype
,
quantization
=
quantization
,
revision
=
revision
,
tokenizer_revision
=
tokenizer_revision
,
seed
=
seed
,
gpu_memory_utilization
=
gpu_memory_utilization
,
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
# For now use ray for the distributed back-end, since
# we rely on the use of engine_use_ray=True to avoid
# reinitializing CUDA in the same process (driver worker)
engine_use_ray
=
True
,
distributed_executor_backend
=
"ray"
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
)
self
.
request_counter
=
Counter
()
self
.
llm_engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
def
generate
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalDataDict
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
List
[
RequestOutput
]:
if
prompts
is
None
:
raise
ValueError
(
"prompts must be provided."
)
if
isinstance
(
prompts
,
str
):
# Convert a single prompt to a list.
prompts
=
[
prompts
]
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and "
"sampling_params must be the same."
)
async
def
get_output
(
prompt
,
sampling_param
)
->
RequestOutput
:
request_id
=
random_uuid
()
results_generator
=
self
.
llm_engine
.
generate
(
prompt
,
sampling_param
,
request_id
)
final_output
=
None
async
for
request_output
in
results_generator
:
final_output
=
request_output
assert
final_output
is
not
None
return
final_output
outputs
:
List
[
RequestOutput
]
=
[]
try
:
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
params
=
sampling_params
[
i
]
if
isinstance
(
sampling_params
,
Sequence
)
else
sampling_params
res
=
asyncio
.
run
(
get_output
(
prompt
,
params
))
outputs
.
append
(
res
)
finally
:
ray
.
shutdown
()
return
outputs
@
pytest
.
fixture
def
baseline_llm_generator
(
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
seed
):
return
create_llm_generator
(
"baseline"
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
seed
)
PROMPTS
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
@
pytest
.
fixture
def
test_llm_generator
(
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
def
test_llm_generator
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
seed
):
return
create_llm_generator
(
"test"
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
test_llm_kwargs
,
seed
)
def
create_llm_generator
(
baseline_or_test
,
request
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
distinct_llm_kwargs
,
seed
):
def
generate
():
kwargs
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
distinc
t_llm_kwargs
,
**
tes
t_llm_kwargs
,
}
test_name
=
request
.
node
.
name
model
=
kwargs
[
"model"
]
draft_model
=
kwargs
.
get
(
"speculative_model"
,
None
)
same_draft_target_model
=
(
draft_model
is
not
None
and
draft_model
==
model
)
def
generator_inner
():
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
)
use_async
=
False
if
"use_async"
in
kwargs
:
use_async
=
kwargs
.
pop
(
"use_async"
)
print
(
f
'
{
use_async
=
}
'
)
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
# Override logging interval to 0 for spec decode test run to
# log all metrics in time.
if
(
baseline_or_test
==
"test"
and
not
use_async
and
llm
.
llm_engine
.
log_stats
):
for
sate_logger
in
llm
.
llm_engine
.
stat_loggers
.
values
():
sate_logger
.
local_interval
=
0
llm
=
LLM
(
**
kwargs
)
if
seed
is
not
None
:
set_random_seed
(
seed
)
yield
llm
del
llm
cleanup
()
def
generator_outer
():
for
llm
in
generator_inner
():
yield
llm
del
llm
cleanup
()
# Set an attribute to the generator_outer function to allow us to
# determine whether to further check the acceptance rate in tests.
generator_outer
.
same_draft_target_model
=
same_draft_target_model
# type: ignore
return
generator_outer
return
generate
def
maybe_assert_ngram_worker
(
llm
):
# Verify the proposer worker is ngram if ngram is specified.
if
(
not
isinstance
(
llm
,
AsyncLLM
)
and
llm
.
llm_engine
.
speculative_config
is
not
None
if
(
llm
.
llm_engine
.
speculative_config
is
not
None
and
llm
.
llm_engine
.
speculative_config
.
ngram_prompt_lookup_max
>
0
):
from
vllm.spec_decode.ngram_worker
import
NGramWorker
assert
isinstance
(
...
...
@@ -251,118 +81,165 @@ def get_output_from_llm_generator(
return
tokens
,
token_ids
,
acceptance_rate
def
get_logprobs_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
List
[
List
[
Dict
[
int
,
Logprob
]]]:
"""Returns a dict of (token_id: Logprob) for each generated position, for
each sequence in the batch.
"""
for
llm
in
llm_generator
():
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
logprobs
=
[
output
.
outputs
[
0
].
logprobs
[:]
for
output
in
outputs
]
del
llm
def
run_logprob_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
max_output_len
:
int
,
seed
:
Optional
[
int
]
=
0
,
temperature
:
float
=
0.0
,
logprobs
:
int
=
1
):
org_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
baseline_llm_kwargs
,
}
return
logprobs
sd_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
test_llm_kwargs
,
}
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
PROMPTS
),
range
(
batch_size
))]
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
,
ensure_all_accepted
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
sampling_params
=
SamplingParams
(
temperature
=
temperature
,
max_tokens
=
max_output_len
,
seed
=
seed
,
logprobs
=
logprobs
)
with
vllm_runner
(
**
org_args
)
as
vllm_model
:
org_outputs
=
vllm_model
.
generate_w_logprobs
(
prompts
,
sampling_params
)
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
,
temperature
=
0.0
,
seeded
=
False
,
print_tokens
=
print_tokens
,
ensure_all_accepted
=
ensure_all_accepted
)
with
vllm_runner
(
**
sd_args
)
as
vllm_model
:
sd_outputs
=
vllm_model
.
generate_w_logprobs
(
prompts
,
sampling_params
)
check_logprobs_close
(
outputs_0_lst
=
org_outputs
,
outputs_1_lst
=
sd_outputs
,
name_0
=
"org"
,
name_1
=
"sd"
)
def
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
temperature
:
float
,
seeded
:
bool
,
print_tokens
:
bool
=
False
,
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
max_output_len
:
int
,
seed
:
Optional
[
int
]
=
0
,
temperature
:
float
=
0.0
,
disable_seed
:
bool
=
False
,
ignore_eos
:
bool
=
True
,
ensure_all_accepted
:
bool
=
False
,
expected_acceptance_rate
:
Optional
[
float
]
=
None
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
"""
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
org_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
baseline_llm_kwargs
,
}
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sd_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
test_llm_kwargs
,
}
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
PROMPTS
),
range
(
batch_size
))]
if
seeded
:
sampling_params
=
[
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
seed
=
i
,
)
for
i
in
range
(
len
(
prompts
))
]
else
:
sampling_params
=
SamplingParams
(
if
disable_seed
:
seed
=
None
sampling_params
=
SamplingParams
(
temperature
=
temperature
,
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
(
spec_batch_tokens
,
spec_batch_token_ids
,
acceptance_rate
)
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
,
_
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
print
(
f
'
{
acceptance_rate
=
}
'
)
seed
=
seed
,
ignore_eos
=
ignore_eos
)
with
vllm_runner
(
**
org_args
)
as
vllm_model
:
org_outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
with
vllm_runner
(
**
sd_args
)
as
vllm_model
:
if
ensure_all_accepted
or
expected_acceptance_rate
is
not
None
:
# Force log interval to be 0 to catch all metrics.
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_loggers
[
'prometheus'
]
stat_logger
.
local_interval
=
-
100
sd_outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
if
ensure_all_accepted
or
expected_acceptance_rate
is
not
None
:
acceptance_rate
=
(
stat_logger
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
.
labels
(
**
stat_logger
.
labels
).
_value
.
get
())
if
ensure_all_accepted
:
assert
acceptance_rate
==
1.0
assert
True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0
if
expected_acceptance_rate
is
not
None
:
assert
acceptance_rate
>=
expected_acceptance_rate
-
1e-2
check_outputs_equal
(
outputs_0_lst
=
org_outputs
,
outputs_1_lst
=
sd_outputs
,
name_0
=
"org"
,
name_1
=
"sd"
)
def
run_equality_correctness_test_tp
(
model
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
max_output_len
:
int
,
seed
:
int
=
0
,
temperature
:
float
=
0.0
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
arg1
=
common_llm_kwargs
+
per_test_common_llm_kwargs
+
baseline_llm_kwargs
arg2
=
common_llm_kwargs
+
per_test_common_llm_kwargs
+
test_llm_kwargs
env1
=
env2
=
None
max_wait_seconds
=
240
results
=
[]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
PROMPTS
),
range
(
batch_size
))]
for
args
,
env
in
((
arg1
,
env1
),
(
arg2
,
env2
)):
with
RemoteOpenAIServer
(
model
,
args
,
env_dict
=
env
,
max_wait_seconds
=
max_wait_seconds
)
as
server
:
client
=
server
.
get_client
()
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
prompts
,
max_tokens
=
max_output_len
,
seed
=
seed
,
temperature
=
temperature
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
[
choice
.
text
for
choice
in
completion
.
choices
],
"finish_reason"
:
[
choice
.
finish_reason
for
choice
in
completion
.
choices
],
"usage"
:
completion
.
usage
,
})
n
=
len
(
results
)
//
2
arg1_results
=
results
[:
n
]
arg2_results
=
results
[
n
:]
for
arg1_result
,
arg2_result
in
zip
(
arg1_results
,
arg2_results
):
assert
arg1_result
==
arg2_result
,
(
f
"Results for
{
model
=
}
are not the same with
{
arg1
=
}
and
{
arg2
=
}
. "
f
"
{
arg1_result
=
}
!=
{
arg2_result
=
}
"
)
tests/spec_decode/e2e/test_eagle_correctness.py
View file @
775f00f8
...
...
@@ -21,7 +21,7 @@ correctess for the target model outputs.
import
pytest
from
.conftest
import
run_
greedy_
equality_correctness_test
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-68m"
...
...
@@ -53,7 +53,7 @@ PRECISION = "float32"
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -68,15 +68,16 @@ PRECISION = "float32"
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
def
test_eagle_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -94,7 +95,7 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -109,17 +110,16 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness_cuda_graph
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_eagle_e2e_greedy_correctness_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -140,7 +140,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -158,18 +158,17 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_eagle_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -185,7 +184,7 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -207,16 +206,17 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_eagle_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -232,7 +232,7 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -250,17 +250,18 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_eagle_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_eagle_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
)
if
__name__
==
"__main__"
:
...
...
tests/spec_decode/e2e/test_integration.py
View file @
775f00f8
...
...
@@ -4,7 +4,9 @@ other features, e.g. cuda graphs.
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
from
.conftest
import
run_equality_correctness_test
MAIN_MODEL
=
"JackFram/llama-68m"
@
pytest
.
mark
.
parametrize
(
...
...
@@ -15,7 +17,7 @@ from .conftest import run_greedy_equality_correctness_test
# Verify equality when cuda graphs allowed.
"enforce_eager"
:
False
,
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
...
...
@@ -31,23 +33,27 @@ from .conftest import run_greedy_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_cuda_graph
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
output_len
):
def
test_spec_decode_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -80,13 +86,19 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_speculative_model_quantization_config
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
):
def
test_speculative_model_quantization_config
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
seed
:
int
):
"""Verify spec decode works well with draft model quantization configs.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_integration_dist_tp2.py
View file @
775f00f8
...
...
@@ -7,42 +7,39 @@ import torch
from
vllm.utils
import
is_hip
from
.conftest
import
run_
greedy_
equality_correctness_test
from
.conftest
import
run_equality_correctness_test
_tp
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
[[
# Skip cuda graph recording for fast test.
"enforce
_
eager"
:
True
,
"
--
enforce
-
eager"
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"tensor_parallel_size"
:
2
,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
"--use-v2-block-manager"
,
"--tensor-parallel-size"
,
"2"
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
[
"--speculative-model"
,
"JackFram/llama-68m"
,
"--num-speculative-tokens"
,
"3"
,
],
[
"--speculative-model"
,
"[ngram]"
,
"--num-speculative-tokens"
,
"5"
,
"--ngram-prompt-lookup-max"
,
"3"
,
],
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -52,75 +49,75 @@ from .conftest import run_greedy_equality_correctness_test
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_target_model_tp_gt_1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_target_model_tp_gt_1
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when tensor parallelism is used.
"""
if
is_hip
():
pytest
.
skip
(
"hip is not well-supported yet"
)
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test_tp
(
"JackFram/llama-68m"
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
output_len
,
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[
{
[
[
# Skip cuda graph recording for fast test.
"enforce
_
eager"
:
True
,
"
--
enforce
-
eager"
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"tensor_parallel_size"
:
2
,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
"--use_v2_block_manager"
,
"--tensor_parallel_size"
,
"2"
,
# precision
"dtype"
:
"float32"
,
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs, test_llm_kwargs"
,
[
(
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a
# tokenizer.
"model"
:
"JackFram/llama-68m"
,
},
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
"speculative_draft_tensor_parallel_size"
:
1
,
}),
({
"model"
:
"ibm-granite/granite-3b-code-instruct"
,
},
{
"speculative_model"
:
"ibm-granite/granite-3b-code-instruct-accelerator"
,
"num_speculative_tokens"
:
5
,
"speculative_draft_tensor_parallel_size"
:
1
,
})
])
"--dtype"
,
"bfloat16"
,
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"model, test_llm_kwargs"
,
[(
"JackFram/llama-68m"
,
[
"--speculative-model"
,
"JackFram/llama-68m"
,
"--num_speculative-tokens"
,
"5"
,
"--speculative-draft-tensor-parallel-size"
,
"1"
,
]),
(
"ibm-granite/granite-3b-code-instruct"
,
[
"--speculative-model"
,
"ibm-granite/granite-3b-code-instruct"
,
"--num_speculative-tokens"
,
"5"
,
"--speculative-draft-tensor-parallel-size"
,
"1"
,
])])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_draft_model_tp_lt_target_model_tp2
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
:
int
):
def
test_draft_model_tp_lt_target_model_tp2
(
model
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
seed
:
int
):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test_tp
(
model
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_integration_dist_tp4.py
View file @
775f00f8
...
...
@@ -2,98 +2,97 @@
tensor parallelism.
"""
import
openai
import
pytest
import
torch
from
.conftest
import
run_greedy_equality_correctness_test
from
.conftest
import
run_equality_correctness_test_tp
MAIN_MODEL
=
"JackFram/llama-68m"
SPEC_MODEL
=
"JackFram/llama-68m"
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Need at least 4 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
[[
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
"
--
enforce_eager"
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"tensor_parallel_size"
:
4
,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
}])
"--use-v2-block-manager"
,
"--tensor-parallel-size"
,
"4"
,
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
[
"--speculative-model"
,
f
"
{
SPEC_MODEL
}
"
,
"--num-speculative-tokens"
,
"5"
,
],
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[
{}
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[
[]
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
#TODO(wooyeon): add spec_draft_dp=2 case
{
"speculative_draft_tensor_parallel_size"
:
1
,
},
[
"--speculative-draft-tensor-parallel-size"
,
"1"
,
],
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_draft_model_tp_lt_target_model_tp4
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
:
int
):
def
test_draft_model_tp_lt_target_model_tp4
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
seed
:
int
):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test_tp
(
MAIN_MODEL
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Need at least 4 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
[[
# Skip cuda graph recording for fast test.
"enforce
_
eager"
:
True
,
"
--
enforce
-
eager"
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"tensor_parallel_size"
:
4
,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
"--use-v2-block-manager"
,
"--tensor-parallel-size"
,
"4"
,
]])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[[]])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
[
"--speculative-model"
,
f
"
{
SPEC_MODEL
}
"
,
"--num-speculative-tokens"
,
"5"
,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
},
"--speculative-max-model-len"
,
"32"
,
],
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -105,8 +104,9 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_skip_speculation
(
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
...
...
@@ -114,9 +114,13 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
TODO: fix it to pass without raising Error. (#5814)
"""
with
pytest
.
raises
(
RuntimeError
):
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
with
pytest
.
raises
(
openai
.
APIConnectionError
):
run_equality_correctness_test_tp
(
MAIN_MODEL
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
output_len
,
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_logprobs.py
View file @
775f00f8
import
math
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
from
.conftest
import
get
_logprob
s_from_llm_generator
from
.conftest
import
run
_logprob
_correctness_test
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -36,64 +34,29 @@ from .conftest import get_logprobs_from_llm_generator
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_equality
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_logprobs_equality
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_logprob_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_diff_num_logprobs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
,
num_logprobs
:
int
):
"""Verify output logprobs are equal with and without spec decode.
This specifies a number of logprobs >1.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
logprob_rank
=
num_logprobs
)
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -121,21 +84,29 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_logprobs_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_logprob_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -164,22 +135,30 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_when_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
])
def
test_logprobs_when_skip_speculation
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_logprob_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -203,19 +182,17 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_temp_1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
6
])
def
test_logprobs_temp_1
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
batch_size
=
8
max_output_len
=
output_len
force_output_len
=
True
logprob_rank
=
5
temperature
=
1.0
prompts
=
[
...
...
@@ -231,129 +208,40 @@ def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_
output_len
,
ignore_eos
=
ignore_eos
,
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
logprobs
=
logprob
_rank
,
logprobs
=
logprob
s
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
sd_args
=
{
**
common_llm_kwargs
,
**
per_test_common_llm_kwargs
,
**
test_llm_kwargs
,
}
with
vllm_runner
(
**
sd_args
)
as
vllm_model
:
sd_outputs
=
vllm_model
.
generate_w_logprobs
(
prompts
,
sampling_params
)
num_returned_logprobs
=
[
len
(
logprob_dict
)
for
seq_logprobs
in
spec_batch_logprobs
for
logprob_dict
in
seq_logprobs
len
(
seq_logprobs
)
for
seq_logprobs
in
sd_outputs
[
-
1
]
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert
any
([
num_returned
>
logprob_rank
for
num_returned
in
num_returned_logprobs
])
def
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
logprob_rank
:
int
=
1
):
"""Helper method that compares the logprobs outputs of both the baseline LLM
and the test LLM. It asserts greedy equality of the logprobs when the
temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
logprobs
=
logprob_rank
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
baseline_batch_logprobs
=
get_logprobs_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_logprobs
)
==
len
(
prompts
)
assert
len
(
spec_batch_logprobs
)
==
len
(
prompts
)
# For each sequence in the batch.
for
i
,
(
baseline_logprobs
,
spec_logprobs
)
in
enumerate
(
zip
(
baseline_batch_logprobs
,
spec_batch_logprobs
)):
assert
len
(
spec_logprobs
)
==
len
(
baseline_logprobs
)
# For each generated position of the sequence.
for
pos
,
(
spec_pos_logprobs
,
baseline_pos_logprobs
)
in
enumerate
(
zip
(
spec_logprobs
,
baseline_logprobs
)):
# Map rank to token/logprob in spec output.
spec_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
spec_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
# Map rank to token/logprob in baseline output.
baseline_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
baseline_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
# Assert set of ranks returned is equal.
assert
set
(
spec_rank_to_token_id
.
keys
())
==
set
(
baseline_rank_to_token_id
.
keys
())
# Assert each logprob/token id is correct, keyed by rank.
for
rank
in
sorted
(
set
(
spec_rank_to_token_id
.
keys
())):
assert
spec_rank_to_token_id
[
rank
]
==
baseline_rank_to_token_id
[
rank
],
f
"
{
rank
}
"
assert
math
.
isclose
(
a
=
spec_rank_to_logprob
[
rank
],
b
=
baseline_rank_to_logprob
[
rank
],
abs_tol
=
1e-1
,
)
assert
any
(
[
num_returned
>
logprobs
for
num_returned
in
num_returned_logprobs
])
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -364,57 +252,28 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
"disable_logprobs_during_spec_decoding"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_disabled
(
baseline_llm_generator
,
test_llm_generator
):
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
0
])
def
test_logprobs_disabled
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
4
))]
sampling_params
=
SamplingParams
(
# Use smaller output len for fast test
max_tokens
=
7
,
ignore_eos
=
True
,
run_logprob_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
temperature
=
0.0
,
logprobs
=
2
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
baseline_batch_logprobs
=
get_logprobs_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_logprobs
)
==
len
(
prompts
)
assert
len
(
spec_batch_logprobs
)
==
len
(
prompts
)
# For each sequence in the batch.
for
_
,
(
baseline_logprobs
,
spec_logprobs
)
in
enumerate
(
zip
(
baseline_batch_logprobs
,
spec_batch_logprobs
)):
assert
len
(
spec_logprobs
)
==
len
(
baseline_logprobs
)
# For each generated position of the sequence.
for
_
,
(
spec_pos_logprobs
,
baseline_pos_logprobs
)
in
enumerate
(
zip
(
spec_logprobs
,
baseline_logprobs
)):
assert
len
(
spec_pos_logprobs
)
==
1
spec_top_token_id
=
list
(
spec_pos_logprobs
)[
0
]
spec_top_logprob
=
spec_pos_logprobs
[
spec_top_token_id
]
assert
spec_top_logprob
.
logprob
==
0.0
assert
spec_top_logprob
.
rank
==
-
1
# check that the chosen token matches the base model
baseline_logprob
=
baseline_pos_logprobs
[
spec_top_token_id
]
assert
baseline_logprob
.
rank
==
1
assert
spec_top_logprob
.
decoded_token
\
==
baseline_logprob
.
decoded_token
logprobs
=
logprobs
)
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
775f00f8
...
...
@@ -21,7 +21,7 @@ correctess for the target model outputs.
import
pytest
from
.conftest
import
run_
greedy_
equality_correctness_test
from
.conftest
import
run_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
...
...
@@ -55,7 +55,7 @@ PRECISION = "float32"
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -70,15 +70,21 @@ PRECISION = "float32"
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_medusa_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -96,7 +102,7 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -111,17 +117,21 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_e2e_greedy_correctness_cuda_graph
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_medusa_e2e_greedy_correctness_cuda_graph
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -142,7 +152,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -160,18 +170,22 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_medusa_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -187,7 +201,7 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -209,16 +223,22 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_medusa_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -234,7 +254,7 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -252,17 +272,23 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_medusa_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_medusa_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
if
__name__
==
"__main__"
:
...
...
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
775f00f8
...
...
@@ -25,8 +25,7 @@ import pytest
from
vllm.model_executor.layers.vocab_parallel_embedding
import
pad_vocab_size
from
.conftest
import
(
run_equality_correctness_test
,
run_greedy_equality_correctness_test
)
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-160m"
...
...
@@ -58,7 +57,7 @@ PRECISION = "float32"
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -72,14 +71,21 @@ PRECISION = "float32"
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -98,7 +104,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -110,17 +116,21 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_acceptance_rate
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_e2e_acceptance_rate
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seeded
=
True
,
force_output_len
=
True
,
seed
=
seed
,
expected_acceptance_rate
=
0.48
)
...
...
@@ -140,7 +150,7 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
# Speculative model
"speculative_model"
:
SPEC_MODEL
,
...
...
@@ -151,28 +161,35 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_mlp_e2e_seeded_correctness
(
baseline_llm_generator
,
test_llm_generator
,
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_seeded_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
temperature
:
float
):
temperature
:
float
,
seed
:
int
):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
True
,
force_output_len
=
True
)
seed
=
seed
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seed
ed
=
False
,
force_output_len
=
True
)
seed
=
seed
,
disable_seed
=
True
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -193,7 +210,7 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -210,18 +227,22 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -242,7 +263,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -259,10 +280,10 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness_with_padding
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_e2e_greedy_correctness_with_padding
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality when the vocab dimension is padded
"""
...
...
@@ -273,11 +294,15 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
with
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size"
,
patched_pad_vocab_size
):
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -293,7 +318,7 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -315,16 +340,22 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
seed
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -340,7 +371,7 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
"model
_name
"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -357,14 +388,20 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_mlp_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
seed
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
775f00f8
This diff is collapsed.
Click to expand it.
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
775f00f8
...
...
@@ -26,7 +26,7 @@ for the target model outputs.
import
pytest
from
.conftest
import
run_
greedy_
equality_correctness_test
from
.conftest
import
run_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
...
...
@@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_e2e_greedy_correctness
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
"model
_name
"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
...
...
@@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_e2e_greedy_correctness_with_preemption
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
temperature
=
0
,
seed
=
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_different_k
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
def
test_ngram_disable_queue
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_seed.py
View file @
775f00f8
...
...
@@ -2,11 +2,17 @@ import pytest
from
.conftest
import
run_equality_correctness_test
# main model
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"JackFram/llama-160m"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
"model
_name
"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
...
...
@@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test
# Use smaller output len for fast test.
20
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
temperature
:
floa
t
,
output_len
:
int
):
def
test_seeded_consistency
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
in
t
,
temperature
:
float
,
output_len
:
int
):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
Tru
e
,
force_output_len
=
True
)
disable_seed
=
Fals
e
,
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
Fals
e
,
force_output_len
=
True
)
disable_seed
=
Tru
e
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment