Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1600 additions
and
366 deletions
+1600
-366
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+1
-1
tests/samplers/test_ignore_eos.py
tests/samplers/test_ignore_eos.py
+31
-0
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+41
-3
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+46
-34
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+242
-3
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+11
-4
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+335
-0
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+41
-59
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+172
-0
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+80
-70
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+206
-0
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+73
-57
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+9
-49
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
+3
-3
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+2
-2
tests/test_logger.py
tests/test_logger.py
+188
-1
tests/test_logits_processor.py
tests/test_logits_processor.py
+10
-6
tests/tokenization/test_tokenizer.py
tests/tokenization/test_tokenizer.py
+20
-0
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+69
-64
tests/worker/test_swap.py
tests/worker/test_swap.py
+20
-10
No files found.
tests/quantization/test_fp8.py
View file @
1591c68f
...
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
...
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
linear
_method
,
Fp8LinearMethod
)
assert
isinstance
(
fc1
.
quant
_method
,
Fp8LinearMethod
)
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
tests/samplers/test_ignore_eos.py
0 → 100644
View file @
1591c68f
"""Make sure ignore_eos works.
Run `pytest tests/samplers/test_ignore_eos.py`.
"""
import
pytest
from
vllm
import
SamplingParams
MODELS
=
[
"facebook/opt-125m"
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
1024
])
def
test_beam_search_single_input
(
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
example_prompts
=
"1 + 1 is"
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
ignore_eos_output
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
sampling_params
)
print
(
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
))
assert
max_tokens
-
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
<
10
assert
max_tokens
-
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
>=
0
tests/samplers/test_logprobs.py
View file @
1591c68f
...
@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
...
@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
def
test_get_prompt_logprobs
(
def
test_get_prompt_logprobs
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
model
,
model
,
dtype
,
dtype
,
chunked_prefill_token_size
:
int
,
num_top_logprobs
:
int
,
example_prompts
,
example_prompts
,
):
):
max_num_seqs
=
256
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_seqs
=
min
(
chunked_prefill_token_size
,
max_num_seqs
)
max_num_batched_tokens
=
chunked_prefill_token_size
max_tokens
=
5
max_tokens
=
5
num_top_logprobs
=
6
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
example_prompts
,
example_prompts
,
...
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
...
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
)
)
del
hf_model
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_seqs
,
)
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
num_top_logprobs
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
5
,
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
)
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
...
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
...
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position "
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
"should be the same as the output text in the result."
)
# The first prompt logprob is always None
assert
result
.
prompt_logprobs
[
0
]
is
None
for
prompt_logprobs
in
result
.
prompt_logprobs
[
1
:]:
# If the prompt token is not included in the top X
# logprob, it can return 1 more data
assert
(
len
(
prompt_logprobs
)
==
num_top_logprobs
or
len
(
prompt_logprobs
)
==
num_top_logprobs
+
1
)
# Test whether prompt logprobs are consistent with HF
# Test whether prompt logprobs are consistent with HF
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
# Check prompt logprobs
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
...
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
...
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned "
"The token should be decoded by the time it is returned "
" to the user."
)
" to the user."
)
# Test if prompt logprobs are correctly set.
for
vllm_result
in
vllm_results
:
token_ids
=
vllm_result
.
prompt_token_ids
prompt_logprobs
=
vllm_result
.
prompt_logprobs
# The first token doesn't have logprob.
assert
prompt_logprobs
[
0
]
is
None
for
token_id
,
logprob_dict
in
zip
(
token_ids
[
1
:],
prompt_logprobs
[
1
:]):
assert
token_id
in
logprob_dict
def
test_max_logprobs
():
def
test_max_logprobs
():
runner
=
VllmRunner
(
"facebook/opt-125m"
,
max_logprobs
=
1
)
runner
=
VllmRunner
(
"facebook/opt-125m"
,
max_logprobs
=
1
)
...
...
tests/samplers/test_sampler.py
View file @
1591c68f
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
transformers
import
GenerationConfig
,
GenerationMixin
from
transformers
import
GenerationConfig
,
GenerationMixin
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -54,9 +55,10 @@ def _do_sample(
...
@@ -54,9 +55,10 @@ def _do_sample(
sampler
:
MockLogitsSampler
,
sampler
:
MockLogitsSampler
,
model_runner
:
ModelRunner
,
model_runner
:
ModelRunner
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
device
:
str
,
):
):
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -66,11 +68,14 @@ def _do_sample(
...
@@ -66,11 +68,14 @@ def _do_sample(
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
...
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
...
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
...
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
...
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
,
device
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
...
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
,
device
)
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
,
device
)
assert
first_sampler_output
==
second_sampler_output
assert
first_sampler_output
==
second_sampler_output
...
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
...
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of
=
2
,
best_of
=
2
,
use_beam_search
=
True
,
use_beam_search
=
True
,
)
)
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
,
device
)
# no assertion here as I am not sure how to determine whether
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# whether there are no exceptions in the sampler
...
@@ -201,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -201,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sampling_params
(
min_tokens
,
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
eos_token_id
=
0
,
*
,
*
,
stop_token_ids
:
Optional
[
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
...
@@ -210,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -210,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# requesting prompt_logprobs changes the structure of `logits`
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs
=
prompt_logprobs
,
prompt_logprobs
=
prompt_logprobs
,
)
)
sampling_params
.
eos
_token_id
=
eos_token_id
sampling_params
.
all_stop
_token_id
s
.
add
(
eos_token_id
)
return
sampling_params
return
sampling_params
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
...
@@ -415,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -415,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list"
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
batch_size
=
0
prompt
_lens
=
[]
seq
_lens
=
[]
sampling_params_per_row
=
[]
sampling_params_per_row
=
[]
for
sgm
in
seq_group_metadata_list
:
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
sampling_params
=
sgm
.
sampling_params
...
@@ -425,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -425,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence
# a prompt seq_group has only one sequence
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
seq_data
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
prompt_len
=
seq_data
.
get_prompt_len
()
prompt_len
=
seq_data
.
get_prompt_len
()
prompt
_lens
.
append
(
prompt_len
)
seq
_lens
.
append
(
prompt_len
)
if
sgm
.
sampling_params
.
prompt_logprobs
:
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# with prompt_logprobs each token in the prompt has a row in
...
@@ -443,20 +449,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -443,20 +449,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"batch size"
)
"batch size"
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_metadata
=
model_runner
.
_prepare_sampl
e
(
sampling_metadata
=
SamplingMetadata
.
prepar
e
(
seq_group_metadata_list
,
seq_group_metadata_list
,
prompt_lens
=
prompt_lens
if
prompt_lens
else
None
,
seq_lens
=
seq_lens
if
seq_lens
else
None
,
subquery_lens
=
prompt_lens
if
prompt_lens
else
None
)
query_lens
=
seq_lens
if
seq_lens
else
None
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
# the logits tensor is modified in-place by the sampler
# the logits tensor is modified in-place by the sampler
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
_
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
for
logits_idx
,
(
should_penalize
,
sampling_params
)
in
enumerate
(
zip
(
expected_penalization
,
sampling_params_per_row
)):
zip
(
expected_penalization
,
sampling_params_per_row
)):
tokens_to_check
=
[
sampling_params
.
eos_token_id
]
tokens_to_check
=
sampling_params
.
all_stop_token_ids
if
sampling_params
.
stop_token_ids
:
tokens_to_check
.
extend
(
sampling_params
.
stop_token_ids
)
tokens_to_check
=
set
(
tokens_to_check
)
if
should_penalize
:
if
should_penalize
:
for
token_id
in
tokens_to_check
:
for
token_id
in
tokens_to_check
:
...
@@ -492,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -492,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
expected_tokens
:
List
[
Optional
[
List
[
int
]]]
=
[]
expected_tokens
:
List
[
Optional
[
List
[
int
]]]
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected
:
Optional
[
List
[
int
]]
=
None
expected
:
Optional
[
List
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
3
)
sampling_type
=
random
.
randint
(
0
,
3
)
...
@@ -527,11 +532,15 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -527,11 +532,15 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
prompt
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq
_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
def
test_sampling
(
model_runner
:
ModelRunner
):
def
test_sampling
(
model_runner
:
ModelRunner
):
sampling_metadata
=
model_runner
.
_prepare_sample
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
seq_group_metadata_list
,
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampler_output
=
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
...
@@ -566,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -566,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample
# Shuffle the batch and resample
target_index
=
list
(
range
(
batch_size
))
target_index
=
list
(
range
(
batch_size
))
for
list_to_shuffle
in
(
target_index
,
seq_group_metadata_list
,
for
list_to_shuffle
in
(
target_index
,
seq_group_metadata_list
,
expected_tokens
,
prompt
_lens
):
expected_tokens
,
seq
_lens
):
random
.
Random
(
seed
).
shuffle
(
list_to_shuffle
)
random
.
Random
(
seed
).
shuffle
(
list_to_shuffle
)
target_index
=
torch
.
tensor
(
target_index
)
target_index
=
torch
.
tensor
(
target_index
)
input_tensor
.
data
=
input_tensor
.
index_select
(
0
,
target_index
)
input_tensor
.
data
=
input_tensor
.
index_select
(
0
,
target_index
)
...
@@ -611,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -611,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert
len
(
warpers
)
==
2
# top_p and top_k
assert
len
(
warpers
)
==
2
# top_p and top_k
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -625,11 +634,14 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -625,11 +634,14 @@ def test_sampler_top_k_top_p(seed: int, device: str):
),
),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
seq_lens
,
query_lens
=
seq_lens
,
device
=
device
,
pin_memory
=
model_runner
.
pin_memory
)
sample_probs
=
None
sample_probs
=
None
...
...
tests/spec_decode/e2e/conftest.py
View file @
1591c68f
from
typing
import
List
,
Tuple
import
asyncio
import
time
from
itertools
import
cycle
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
pytest
import
pytest
import
ray
import
torch
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
)
from
tests.conftest
import
cleanup
from
tests.conftest
import
cleanup
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
,
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
class
AsyncLLM
:
"""AsyncLLM
Note: Current LLM class in vllm don't support async mode, for test purpose,
we implement async one in here. Maybe we could move to
vllm/entrypoints/llm.py in future.
Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes
to make to work in async mode.
"""
def
__init__
(
self
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer_mode
:
str
=
"auto"
,
skip_tokenizer_init
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
tensor_parallel_size
:
int
=
1
,
dtype
:
str
=
"auto"
,
quantization
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
)
->
None
:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
self
.
engine_args
=
AsyncEngineArgs
(
model
=
model
,
tokenizer
=
tokenizer
,
tokenizer_mode
=
tokenizer_mode
,
skip_tokenizer_init
=
skip_tokenizer_init
,
trust_remote_code
=
trust_remote_code
,
tensor_parallel_size
=
tensor_parallel_size
,
dtype
=
dtype
,
quantization
=
quantization
,
revision
=
revision
,
tokenizer_revision
=
tokenizer_revision
,
seed
=
seed
,
gpu_memory_utilization
=
gpu_memory_utilization
,
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
engine_use_ray
=
True
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
)
self
.
request_counter
=
Counter
()
def
generate
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
llm_engine
=
AsyncLLMEngine
.
from_engine_args
(
self
.
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
if
prompts
is
None
:
raise
ValueError
(
"prompts must be provided."
)
if
isinstance
(
prompts
,
str
):
# Convert a single prompt to a list.
prompts
=
[
prompts
]
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and "
"sampling_params must be the same."
)
async
def
get_output
(
prompt
,
sampling_param
)
->
str
:
request_id
=
random_uuid
()
results_generator
=
llm_engine
.
generate
(
prompt
,
sampling_param
,
request_id
)
final_output
=
None
async
for
request_output
in
results_generator
:
final_output
=
request_output
return
final_output
outputs
=
[]
try
:
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
res
=
asyncio
.
run
(
get_output
(
prompt
,
sampling_params
))
outputs
.
append
(
res
)
finally
:
ray
.
shutdown
()
return
outputs
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -35,9 +157,20 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
...
@@ -35,9 +157,20 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
test_name
=
request
.
node
.
name
test_name
=
request
.
node
.
name
def
generator_inner
():
def
generator_inner
():
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
LLM
(
**
kwargs
)
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
)
use_async
=
False
if
"use_async"
in
kwargs
:
use_async
=
kwargs
.
pop
(
"use_async"
)
print
(
f
'
{
use_async
=
}
'
)
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
yield
llm
yield
llm
...
@@ -64,3 +197,109 @@ def get_output_from_llm_generator(
...
@@ -64,3 +197,109 @@ def get_output_from_llm_generator(
del
llm
del
llm
return
tokens
,
token_ids
return
tokens
,
token_ids
def
get_logprobs_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
List
[
List
[
Dict
[
int
,
Logprob
]]]:
"""Returns a dict of (token_id: Logprob) for each generated position, for
each sequence in the batch.
"""
for
llm
in
llm_generator
():
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
logprobs
=
[
output
.
outputs
[
0
].
logprobs
[:]
for
output
in
outputs
]
del
llm
return
logprobs
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
def
wait_for_gpu_memory_to_clear
(
devices
:
List
[
int
],
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit
()
start_time
=
time
.
time
()
while
True
:
output
=
{}
output_raw
=
{}
for
device
in
devices
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
time
.
sleep
(
5
)
tests/spec_decode/e2e/test_compatibility.py
View file @
1591c68f
...
@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
...
@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
temperature
=
temperature
,
temperature
=
temperature
,
)
)
with
pytest
.
raises
(
AssertionError
,
try
:
match
=
"Speculative decoding not yet supported for "
):
with
pytest
.
raises
(
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
AssertionError
,
sampling_params
)
match
=
"Speculative decoding not yet supported for "
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
finally
:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import
ray
ray
.
shutdown
()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
tests/spec_decode/e2e/test_logprobs.py
0 → 100644
View file @
1591c68f
import
math
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
from
.conftest
import
get_logprobs_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_equality
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"max_logprobs"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
7
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_diff_num_logprobs
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
,
num_logprobs
:
int
):
"""Verify output logprobs are equal with and without spec decode.
This specifies a number of logprobs >1.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
logprob_rank
=
num_logprobs
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
},
{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
6
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_when_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_logprobs_temp_1
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
batch_size
=
8
max_output_len
=
output_len
force_output_len
=
True
logprob_rank
=
5
temperature
=
1.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
logprobs
=
logprob_rank
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
num_returned_logprobs
=
[
len
(
logprob_dict
)
for
seq_logprobs
in
spec_batch_logprobs
for
logprob_dict
in
seq_logprobs
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert
any
([
num_returned
>
logprob_rank
for
num_returned
in
num_returned_logprobs
])
def
run_greedy_logprobs_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
logprob_rank
:
int
=
1
):
"""Helper method that compares the logprobs outputs of both the baseline LLM
and the test LLM. It asserts greedy equality of the logprobs when the
temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
logprobs
=
logprob_rank
,
)
spec_batch_logprobs
=
get_logprobs_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
baseline_batch_logprobs
=
get_logprobs_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_logprobs
)
==
len
(
prompts
)
assert
len
(
spec_batch_logprobs
)
==
len
(
prompts
)
# For each sequence in the batch.
for
i
,
(
baseline_logprobs
,
spec_logprobs
)
in
enumerate
(
zip
(
baseline_batch_logprobs
,
spec_batch_logprobs
)):
assert
len
(
spec_logprobs
)
==
len
(
baseline_logprobs
)
# For each generated position of the sequence.
for
pos
,
(
spec_pos_logprobs
,
baseline_pos_logprobs
)
in
enumerate
(
zip
(
spec_logprobs
,
baseline_logprobs
)):
# Map rank to token/logprob in spec output.
spec_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
spec_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
spec_pos_logprobs
.
items
()
}
# Map rank to token/logprob in baseline output.
baseline_rank_to_token_id
=
{
value
.
rank
:
key
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
baseline_rank_to_logprob
=
{
value
.
rank
:
value
.
logprob
for
key
,
value
in
baseline_pos_logprobs
.
items
()
}
# Assert set of ranks returned is equal.
assert
set
(
spec_rank_to_token_id
.
keys
())
==
set
(
baseline_rank_to_token_id
.
keys
())
# Assert each logprob/token id is correct, keyed by rank.
for
rank
in
sorted
(
set
(
spec_rank_to_token_id
.
keys
())):
assert
spec_rank_to_token_id
[
rank
]
==
baseline_rank_to_token_id
[
rank
],
f
"
{
rank
}
"
assert
math
.
isclose
(
a
=
spec_rank_to_logprob
[
rank
],
b
=
baseline_rank_to_logprob
[
rank
],
abs_tol
=
1e-1
,
)
tests/spec_decode/e2e/test_correctness.py
→
tests/spec_decode/e2e/test_
multistep_
correctness.py
View file @
1591c68f
...
@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
...
@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
.conftest
import
get_output_from_llm_generator
from
.conftest
import
(
get_output_from_llm_generator
,
run_greedy_equality_correctness_test
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -49,7 +50,7 @@ from .conftest import get_output_from_llm_generator
...
@@ -49,7 +50,7 @@ from .conftest import get_output_from_llm_generator
"enforce_eager"
:
True
,
"enforce_eager"
:
True
,
# Required for spec decode.
# Required for spec decode.
"use_v2_block_manager"
:
True
"use_v2_block_manager"
:
True
,
}])
}])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
"per_test_common_llm_kwargs"
,
...
@@ -109,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
...
@@ -109,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
assert
actual_tokens
.
strip
()
==
expected_tokens
.
strip
()
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Use AsyncLLM engine
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
},
])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_spec_decode_e2e_with_async_engine
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
:
int
):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
32
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
@@ -538,60 +577,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
...
@@ -538,60 +577,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size
,
batch_size
,
max_output_len
=
output_len
,
max_output_len
=
output_len
,
force_output_len
=
True
)
force_output_len
=
True
)
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
tests/spec_decode/e2e/test_ngram_correctness.py
0 → 100644
View file @
1591c68f
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the correctess
for the target model outputs.
"""
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
64
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
k
,
"ngram_prompt_lookup_max"
:
3
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
]
+
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
k
,
"ngram_prompt_lookup_max"
:
1
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
tests/spec_decode/test_multi_step_worker.py
View file @
1591c68f
...
@@ -5,13 +5,12 @@ import pytest
...
@@ -5,13 +5,12 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.multi_step_worker
import
(
DraftModelTop1Propos
er
,
from
vllm.spec_decode.multi_step_worker
import
MultiStepWork
er
MultiStepWork
er
)
from
vllm.spec_decode.top1_proposer
import
Top1Propos
er
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
create_execute_model_data
,
create_seq_group_metadata_from_prompts
,
create_worker
,
create_seq_group_metadata_from_prompts
,
create_worker
,
patch_execute_model_with_seeds
,
zero_kv_cache
)
patch_execute_model_with_seeds
,
zero_kv_cache
)
...
@@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int):
...
@@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int):
list
(
range
(
block_size
*
2
)),
list
(
range
(
block_size
*
2
)),
]
]
final_
seq
_lens
=
[
final_
prompt
_lens
=
[
len
(
prompt
+
output
)
+
num_steps
len
(
prompt
+
output
)
+
num_steps
for
prompt
,
output
in
zip
(
prompts
,
prev_output_tokens
)
for
prompt
,
output
in
zip
(
prompts
,
prev_output_tokens
)
]
]
...
@@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int):
...
@@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
final_
seq
_lens
,
final_
prompt
_lens
,
continuations
=
prev_output_tokens
)
continuations
=
prev_output_tokens
)
assert_enough_kv_space
=
MultiStepWorker
.
_assert_enough_kv_space
# pylint: disable=protected-access
assert_enough_kv_space
=
MultiStepWorker
.
_assert_enough_kv_space
# pylint: disable=protected-access
...
@@ -103,29 +102,34 @@ def test_same_output_for_single_step():
...
@@ -103,29 +102,34 @@ def test_same_output_for_single_step():
[
6
,
7
,
8
,
9
,
10
],
[
6
,
7
,
8
,
9
,
10
],
]
]
final_
seq
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
multi_step_execute_model_data
=
create_execute_model_data
(
multi_step_seq_group
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
block_size
,
num_gpu_blocks
,
final_seq_lens
=
final_seq_lens
))
block_size
,
final_prompt_lens
=
final_prompt_lens
)
single_step_execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
actual_output
=
multi_step_worker
.
execute_model_multi_step
(
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
**
multi_step_execute_model_data
.
to_dict
(),
num_steps
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
multi_step_seq_group
),
sample_len
=
num_steps
)
assert
len
(
actual_output
)
==
num_steps
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
actual_output
=
actual_output
[
0
]
single_step_seq_group
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
expected_output
=
worker
.
execute_model
(
expected_output
=
worker
.
execute_model
(
**
single_step_execute_model_data
.
to_dict
(),
)[
0
]
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
single_step_seq_group
))[
0
]
actual_token_ids
=
[
actual_token_ids
=
[
output
.
samples
[
0
].
output_token
for
output
in
actual_output
output
.
samples
[
0
].
output_token
for
output
in
actual_output
...
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
...
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random
.
randint
(
0
,
1000
)
for
_
in
range
(
random
.
randint
(
10
,
20
))
random
.
randint
(
0
,
1000
)
for
_
in
range
(
random
.
randint
(
10
,
20
))
]
for
_
in
range
(
10
)]
]
for
_
in
range
(
10
)]
final_
seq
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
final_
prompt
_lens
=
[
len
(
prompt
)
+
num_steps
for
prompt
in
prompts
]
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
rand_seeds
=
list
(
random
.
randint
(
0
,
100
)
for
_
in
range
(
num_steps
))
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
multi_step_worker
.
execute_model
=
patch_execute_model_with_seeds
(
...
@@ -189,19 +193,20 @@ def test_same_output_for_multi_step():
...
@@ -189,19 +193,20 @@ def test_same_output_for_multi_step():
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
worker
.
execute_model
=
patch_execute_model_with_seeds
(
worker
,
rand_seeds
)
continuations
=
[[
1
]
for
_
in
prompts
]
continuations
=
[[
1
]
for
_
in
prompts
]
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
continuations
=
continuations
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
final_seq_lens
=
final_seq_lens
),
)
# Run multi-step.
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
multi_step_output
=
multi_step_worker
.
execute_model_multi_step
(
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
**
execute_model_data
.
to_dict
(),
num_steps
=
num_steps
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
),
sample_len
=
num_steps
)
# Run single-step repeatedly.
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
...
@@ -211,16 +216,16 @@ def test_same_output_for_multi_step():
...
@@ -211,16 +216,16 @@ def test_same_output_for_multi_step():
for
_
in
multi_step_output
:
for
_
in
multi_step_output
:
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
prompts
,
num_gpu_blocks
,
num_gpu_blocks
,
block_size
,
block_size
,
continuations
=
continuations
,
continuations
=
continuations
,
final_prompt_lens
=
final_prompt_lens
)
final_seq_lens
=
final_seq_lens
))
single_step_output
.
extend
(
single_step_output
.
extend
(
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
))
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
)))
# Append output tokens to new sequence data.
# Append output tokens to new sequence data.
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
for
i
,
seq_group_output
in
enumerate
(
single_step_output
[
-
1
]):
...
@@ -266,7 +271,7 @@ def test_same_output_for_multi_step():
...
@@ -266,7 +271,7 @@ def test_same_output_for_multi_step():
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_draft_proposals_full_speculation_len
():
def
test_draft_proposals_full_speculation_len
():
"""Verify
DraftModel
Top1Proposer correctly handles case where all sequences
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
can speculate.
"""
"""
k
=
10
k
=
10
...
@@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len():
...
@@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len():
device
=
'cuda:0'
device
=
'cuda:0'
draft_worker
=
MagicMock
()
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
proposer
=
Top1Proposer
(
draft_
worker
=
draft_worker
,
worker
=
draft_worker
,
device
=
device
,
device
=
device
,
max_model_len
=
2048
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
max_proposal_len
=
2048
,
)
)
draft_worker
.
execute_model_multi_step
.
return_value
=
[
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
SamplerOutput
(
outputs
=
[],
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
batch_size
,
sampled_token_probs
=
torch
.
rand
(
batch_size
,
vocab_size
,
vocab_size
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
),
size
=
(
batch_size
,
),
device
=
device
,
device
=
device
,
dtype
=
torch
.
long
),
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
)
for
_
in
range
(
k
)
]
]
,
True
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
max_proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len():
...
@@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len():
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_draft_proposals_no_speculations
():
def
test_draft_proposals_no_speculations
():
"""Verify
DraftModel
Top1Proposer correctly handles case where no sequences
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
can speculate.
"""
"""
k
=
10
k
=
10
...
@@ -325,21 +333,20 @@ def test_draft_proposals_no_speculations():
...
@@ -325,21 +333,20 @@ def test_draft_proposals_no_speculations():
prompt_len
=
10
prompt_len
=
10
draft_worker
=
MagicMock
()
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
proposer
=
Top1Proposer
(
draft_
worker
=
draft_worker
,
worker
=
draft_worker
,
device
=
device
,
device
=
device
,
max_model_len
=
prompt_len
+
k
-
1
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
max_proposal_len
=
prompt_len
+
k
-
1
,
)
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
prompt_len
=
prompt_len
)
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
max_proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations():
...
@@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations():
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_draft_proposals_mixed_k
():
def
test_draft_proposals_mixed_k
():
"""Verify
DraftModel
Top1Proposer correctly handles case some sequences can
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
speculate and some can't.
"""
"""
k
=
10
k
=
10
...
@@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k():
...
@@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k():
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
small_prompt_len
]
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
small_prompt_len
]
draft_worker
=
MagicMock
()
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
proposer
=
Top1Proposer
(
draft_
worker
=
draft_worker
,
worker
=
draft_worker
,
device
=
device
,
device
=
device
,
max_model_len
=
long_prompt_len
+
prev_output_token_len
+
k
-
1
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
max_proposal_len
=
long_prompt_len
+
prev_output_token_len
+
k
-
1
,
)
)
draft_worker
.
execute_model_multi_step
.
return_value
=
[
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
SamplerOutput
(
outputs
=
[],
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
expected_num_proposal_seqs
,
sampled_token_probs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
vocab_size
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
logprobs
=
torch
.
rand
(
expected_num_proposal_seqs
,
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
),
sampled_token_ids
=
torch
.
randint
(
sampled_token_ids
=
torch
.
randint
(
low
=
0
,
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
...
@@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k():
...
@@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k():
device
=
device
,
device
=
device
,
dtype
=
torch
.
long
),
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
)
for
_
in
range
(
k
)
]
]
,
True
execute_model_data
,
_
,
_
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
k
,
prompt_len
=
prompt_len
,
prompt_len
=
prompt_len
,
prev_output_token_len
=
prev_output_token_len
,
prev_output_token_len
=
prev_output_token_len
,
)
)
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
**
execute_model_data
.
to_dict
(),
seq_group_metadata_list
=
seq_group_metadata_list
,
max_proposal_len
=
k
,
num_lookahead_slots
=
k
),
)
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
0 → 100644
View file @
1591c68f
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
create_seq_group_metadata_from_prompts
,
create_worker
def
test_ngram_algo_correctness_for_single_no_match
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
1
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
]
def
test_ngram_algo_correctness_for_batches_not_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
# shall find no candidate as exceed max_proposal_len
[
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
5
])
assert
proposals
.
proposal_lens
.
tolist
(
)
==
[
proposal_len
for
_
in
range
(
4
)]
+
[
0
]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
0
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
3
][
i
]
==
prompts
[
3
][
i
+
5
]
assert
proposals
.
proposal_token_ids
[
4
][
i
]
==
-
1
def
test_ngram_algo_correctness_for_batches_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_prompt_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
3
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
proposal_len
for
_
in
range
(
3
)]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
prompts
[
0
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
5
]
tests/spec_decode/test_spec_decode_worker.py
View file @
1591c68f
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
SpecDecodeWorkerMetrics
)
...
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
...
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
from
vllm.spec_decode.spec_decode_worker
import
(
SpecDecodeWorker
,
split_num_cache_blocks_evenly
)
split_num_cache_blocks_evenly
)
from
.utils
import
(
ExecuteModelData
,
create_batch
,
create_sampler_output_list
,
from
.utils
import
create_batch
,
create_sampler_output_list
,
mock_worker
mock_worker
)
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
@@ -33,27 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
...
@@ -33,27 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
metrics_collector
)
exception_secret
=
'artifical stop'
exception_secret
=
'artific
i
al stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
call_args_list
=
draft_worker
.
get_spec_proposals
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
for
args
,
_
in
call_args_list
:
for
args
,
_
in
call_args_list
:
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
actual_execute_model_data
=
args
[
0
]
blocks_to_copy
,
actual_k
)
=
args
assert
actual_execute_model_data
==
execute_model_req
actual_execute_model_data
=
ExecuteModelData
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
assert
actual_execute_model_data
==
execute_model_data
assert
actual_k
==
k
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
6
])
...
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
prompts
,
prev_output_tokens
=
create_batch
(
batch_size
,
k
)
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
...
@@ -101,24 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -101,24 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_probs
=
proposal_probs
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
)
proposal_lens
=
proposal_lens
)
exception_secret
=
'artifical stop'
exception_secret
=
'artific
i
al stop'
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
target_worker
.
execute_model
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
seen_contexts
=
[]
seen_contexts
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
for
args
,
kwargs
in
call_args_list
:
for
_
,
kwargs
in
call_args_list
:
target_execute_model_data
=
ExecuteModelData
.
from_dict
(
kwargs
)
seq_group_metadata_list
=
kwargs
[
"execute_model_req"
].
seq_group_metadata_list
assert
len
(
target_execute_model_data
.
seq_group_metadata_list
)
==
(
assert
len
(
seq_group_metadata_list
)
==
(
k
+
1
)
*
batch_size
k
+
1
)
*
batch_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
(
target_execute_model_data
.
seq_group_metadata_list
):
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
...
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -192,17 +186,24 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
...
@@ -192,17 +186,24 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
exception_secret
=
'artifical stop'
exception_secret
=
'artific
i
al stop'
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
rejection_sampler
.
side_effect
=
ValueError
(
exception_secret
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
assert
len
(
rejection_sampler
.
call_args_list
)
==
1
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
_
,
kwargs
=
rejection_sampler
.
call_args_list
[
0
]
...
@@ -256,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -256,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -273,8 +274,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -273,8 +274,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
@@ -290,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -290,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler
.
return_value
=
rejection_sampler_output
rejection_sampler
.
return_value
=
rejection_sampler_output
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
expected_output
=
create_sampler_output_list
(
expected_output
=
create_sampler_output_list
(
rejection_sampler_output
.
transpose
(
0
,
1
),
[
None
for
_
in
range
(
k
+
1
)])
token_ids
=
rejection_sampler_output
.
transpose
(
0
,
1
),
probs
=
[
None
for
_
in
range
(
k
+
1
)],
logprobs
=
[
None
for
_
in
range
(
k
+
1
)])
seq_ids
=
[
seq_ids
=
[
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
execute_model_data
.
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
]
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
...
@@ -328,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -328,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue
continue
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
assert
actual_by_step
[
i
].
output_token
==
expected_by_step
[
i
].
output_token
i
].
output_token
assert
actual_by_step
[
i
].
logprobs
==
expected_by_step
[
i
].
logprobs
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
])
...
@@ -370,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -370,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
proposal_lens
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
*
k
device
=
'cuda'
)
*
k
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
draft_worker
.
get_spec_proposals
.
return_value
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_token_ids
,
proposal_token_ids
=
proposal_token_ids
,
...
@@ -387,8 +396,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -387,8 +396,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size
,
vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
device
=
'cuda'
)
target_token_logprobs
=
torch
.
rand
(
1
,
batch_size
*
(
k
+
1
),
vocab_size
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_output
=
create_sampler_output_list
(
target_token_ids
,
target_token_probs
)
target_token_probs
,
target_token_logprobs
)
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
target_worker
.
execute_model
.
return_value
=
[
target_output
[
0
]]
...
@@ -409,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
...
@@ -409,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
metrics_collector
.
maybe_collect_rejsample_metrics
.
return_value
=
(
mock_rejsample_metrics
)
mock_rejsample_metrics
)
output
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
output
=
worker
.
execute_model
(
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
assert
output
[
0
].
spec_decode_worker_metrics
==
mock_rejsample_metrics
call_args_list
=
(
call_args_list
=
(
...
@@ -443,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
...
@@ -443,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
prev_output_token_len
=
0
)
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
0
,
5
])
...
@@ -484,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
...
@@ -484,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
rejection_sampler
,
metrics_collector
)
metrics_collector
)
execute_model_data
,
prompts
,
prev_output_tokens
=
create_batch
(
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
batch_size
,
k
,
prev_output_token_len
=
0
)
k
,
prev_output_token_len
=
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
)
out
=
worker
.
execute_model
(
**
execute_model_data
.
to_dict
(),
out
=
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
num_lookahead_slots
=
k
)
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
len
(
out
)
==
1
,
f
"expected only one token output when
{
k
=
}
"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
0
].
probs
is
None
,
"expect gpu tensor references to be None"
assert
out
[
assert
out
[
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
0
].
sampled_tokens
is
None
,
"expect gpu tensor references to be None"
draft_worker
.
execute_model
.
assert_called_once_with
(
draft_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
**
execute_model_data
.
to_dict
())
target_worker
.
execute_model
.
assert_called_once_with
(
execute_model_req
)
target_worker
.
execute_model
.
assert_called_once_with
(
**
execute_model_data
.
to_dict
())
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
...
...
tests/spec_decode/utils.py
View file @
1591c68f
from
dataclasses
import
dataclass
,
fields
from
itertools
import
count
from
itertools
import
count
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
...
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
...
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
@
dataclass
class
ExecuteModelData
:
"""Helper data structure which facilitates cleaner tests.
"""
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
def
to_dict
(
self
):
return
dict
(
(
field
.
name
,
getattr
(
self
,
field
.
name
))
for
field
in
fields
(
self
))
@
classmethod
def
from_dict
(
cls
,
d
):
cleaned
=
dict
((
field
.
name
,
d
[
field
.
name
])
for
field
in
fields
(
cls
))
return
cls
(
**
cleaned
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
return
(
seq_len
+
block_size
-
1
)
//
block_size
def
create_execute_model_data
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
)
->
ExecuteModelData
:
if
blocks_to_swap_in
is
None
:
blocks_to_swap_in
=
{}
if
blocks_to_swap_out
is
None
:
blocks_to_swap_out
=
{}
if
blocks_to_copy
is
None
:
blocks_to_copy
=
{}
return
ExecuteModelData
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
def
mock_worker
(
cls
=
None
,
def
mock_worker
(
cls
=
None
,
vocab_size
:
int
=
30_000
,
vocab_size
:
int
=
30_000
,
max_model_len
:
int
=
2048
,
max_model_len
:
int
=
2048
,
...
@@ -144,7 +103,7 @@ def create_seq_group_metadata_from_prompts(
...
@@ -144,7 +103,7 @@ def create_seq_group_metadata_from_prompts(
prompts
:
List
[
List
[
int
]],
prompts
:
List
[
List
[
int
]],
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
final_
seq
_lens
:
List
[
int
],
final_
prompt
_lens
:
List
[
int
],
continuations
:
Optional
[
List
[
List
[
int
]]]
=
None
,
continuations
:
Optional
[
List
[
List
[
int
]]]
=
None
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
seq_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
...
@@ -162,7 +121,7 @@ def create_seq_group_metadata_from_prompts(
...
@@ -162,7 +121,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks
.
pop
()
free_gpu_blocks
.
pop
()
for
_
in
range
(
round_up_to_next_block
(
final_len
,
block_size
))
for
_
in
range
(
round_up_to_next_block
(
final_len
,
block_size
))
]
]
for
i
,
final_len
in
enumerate
(
final_
seq
_lens
)
for
i
,
final_len
in
enumerate
(
final_
prompt
_lens
)
}
}
return
[
return
[
...
@@ -201,6 +160,7 @@ def assert_logprobs_dict_allclose(
...
@@ -201,6 +160,7 @@ def assert_logprobs_dict_allclose(
def
create_sampler_output_list
(
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
probs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
probs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
logprobs
:
Iterable
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
token_ids_by_step
=
token_ids
.
tolist
()
...
@@ -222,6 +182,7 @@ def create_sampler_output_list(
...
@@ -222,6 +182,7 @@ def create_sampler_output_list(
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
)
for
seq_index
,
token_id
in
enumerate
(
token_ids_by_step
[
step
])
],
],
sampled_token_probs
=
probs
[
step
],
sampled_token_probs
=
probs
[
step
],
logprobs
=
logprobs
[
step
],
sampled_token_ids
=
token_ids
[
step
])
sampled_token_ids
=
token_ids
[
step
])
for
step
in
range
(
num_steps
)
for
step
in
range
(
num_steps
)
]
]
...
@@ -251,13 +212,12 @@ def create_batch(batch_size,
...
@@ -251,13 +212,12 @@ def create_batch(batch_size,
prev_output_tokens
=
[[
prev_output_tokens
=
[[
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
next
(
iterator
)
for
_
in
range
(
prev_output_token_len
)
]
for
_
in
range
(
batch_size
)]
]
for
_
in
range
(
batch_size
)]
final_
seq
_lens
=
[
final_
prompt
_lens
=
[
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
len
(
prompt
)
+
len
(
prev_output_token
)
+
k
+
1
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
for
prompt
,
prev_output_token
in
zip
(
prompts
,
prev_output_tokens
)
]
]
execute_model_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
prompts
,
num_gpu_blocks
,
block_size
,
final_prompt_lens
,
block_size
,
final_seq_lens
,
prev_output_tokens
,
seq_ids
)
prev_output_tokens
,
seq_ids
),
)
return
seq_group_metadata_list
,
prompts
,
prev_output_tokens
return
execute_model_data
,
prompts
,
prev_output_tokens
tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
View file @
1591c68f
...
@@ -6,14 +6,14 @@ import uuid
...
@@ -6,14 +6,14 @@ import uuid
from
functools
import
partial
from
functools
import
partial
from
typing
import
Type
from
typing
import
Type
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
TensorDeserializer
,
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
TensorDeserializer
,
TensorSerializer
,
stream_io
)
TensorSerializer
,
stream_io
)
from
tensorizer.utils
import
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
from
tensorizer.utils
import
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.distributed
import
initialize_model_parallel
from
vllm.distributed
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerArgs
...
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
...
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
os
.
environ
[
"MASTER_ADDR"
]
=
"127.0.0.1"
os
.
environ
[
"MASTER_ADDR"
]
=
"127.0.0.1"
os
.
environ
[
"MASTER_PORT"
]
=
"8080"
os
.
environ
[
"MASTER_PORT"
]
=
"8080"
torch
.
distributed
.
init_process_group
(
world_size
=
1
,
rank
=
0
)
init_
distributed
_environment
(
world_size
=
1
,
rank
=
0
,
local_
rank
=
0
)
initialize_model_parallel
()
initialize_model_parallel
()
keyfile
=
args
.
keyfile
if
args
.
keyfile
else
None
keyfile
=
args
.
keyfile
if
args
.
keyfile
else
None
...
...
tests/tensorizer_loader/test_tensorizer.py
View file @
1591c68f
...
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
...
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance
.
deserialize
.
return_value
=
MagicMock
()
mock_agent_instance
.
deserialize
.
return_value
=
MagicMock
()
result
=
load_with_tensorizer
(
tensorizer_config
,
result
=
load_with_tensorizer
(
tensorizer_config
,
linear
_method
=
mock_linear_method
)
quant
_method
=
mock_linear_method
)
mock_agent
.
assert_called_once_with
(
tensorizer_config
,
mock_agent
.
assert_called_once_with
(
tensorizer_config
,
linear
_method
=
mock_linear_method
)
quant
_method
=
mock_linear_method
)
mock_agent_instance
.
deserialize
.
assert_called_once
()
mock_agent_instance
.
deserialize
.
assert_called_once
()
assert
result
==
mock_agent_instance
.
deserialize
.
return_value
assert
result
==
mock_agent_instance
.
deserialize
.
return_value
...
...
tests/test_logger.py
View file @
1591c68f
import
json
import
logging
import
os
import
os
import
sys
import
sys
import
tempfile
import
tempfile
from
json.decoder
import
JSONDecodeError
from
tempfile
import
NamedTemporaryFile
from
typing
import
Any
from
unittest.mock
import
patch
from
uuid
import
uuid4
from
vllm.logger
import
enable_trace_function_call
import
pytest
from
vllm.logger
import
(
_DATE_FORMAT
,
_FORMAT
,
_configure_vllm_root_logger
,
enable_trace_function_call
,
init_logger
)
from
vllm.logging
import
NewLineFormatter
def
f1
(
x
):
def
f1
(
x
):
...
@@ -25,3 +36,179 @@ def test_trace_function_call():
...
@@ -25,3 +36,179 @@ def test_trace_function_call():
assert
"f2"
in
content
assert
"f2"
in
content
sys
.
settrace
(
None
)
sys
.
settrace
(
None
)
os
.
remove
(
path
)
os
.
remove
(
path
)
def
test_default_vllm_root_logger_configuration
():
"""This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and
VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default
behavior is activated."""
logger
=
logging
.
getLogger
(
"vllm"
)
assert
logger
.
level
==
logging
.
DEBUG
assert
not
logger
.
propagate
handler
=
logger
.
handlers
[
0
]
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
level
==
logging
.
INFO
formatter
=
handler
.
formatter
assert
formatter
is
not
None
assert
isinstance
(
formatter
,
NewLineFormatter
)
assert
formatter
.
_fmt
==
_FORMAT
assert
formatter
.
datefmt
==
_DATE_FORMAT
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
@
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
None
)
def
test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger
():
"""This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and
VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default
behavior is activated."""
root_logger
=
logging
.
getLogger
(
"vllm"
)
root_handler
=
root_logger
.
handlers
[
0
]
unique_name
=
f
"vllm.
{
uuid4
()
}
"
logger
=
init_logger
(
unique_name
)
assert
logger
.
name
==
unique_name
assert
logger
.
level
==
logging
.
NOTSET
assert
not
logger
.
handlers
assert
logger
.
propagate
message
=
"Hello, world!"
with
patch
.
object
(
root_handler
,
"emit"
)
as
root_handle_mock
:
logger
.
info
(
message
)
root_handle_mock
.
assert_called_once
()
_
,
call_args
,
_
=
root_handle_mock
.
mock_calls
[
0
]
log_record
=
call_args
[
0
]
assert
unique_name
==
log_record
.
name
assert
message
==
log_record
.
msg
assert
message
==
log_record
.
msg
assert
log_record
.
levelno
==
logging
.
INFO
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
0
)
@
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
None
)
def
test_logger_configuring_can_be_disabled
():
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
with
patch
(
"logging.config.dictConfig"
)
as
dict_config_mock
:
_configure_vllm_root_logger
()
dict_config_mock
.
assert_not_called
()
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
@
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
"/if/there/is/a/file/here/then/you/did/this/to/yourself.json"
,
)
def
test_an_error_is_raised_when_custom_logging_config_file_does_not_exist
():
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
with
pytest
.
raises
(
RuntimeError
)
as
ex_info
:
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
RuntimeError
assert
"File does not exist"
in
str
(
ex_info
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
def
test_an_error_is_raised_when_custom_logging_config_is_invalid_json
():
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
with
NamedTemporaryFile
(
encoding
=
"utf-8"
,
mode
=
"w"
)
as
logging_config_file
:
logging_config_file
.
write
(
"---
\n
loggers: []
\n
version: 1"
)
logging_config_file
.
flush
()
with
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
logging_config_file
.
name
):
with
pytest
.
raises
(
JSONDecodeError
)
as
ex_info
:
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
JSONDecodeError
assert
"Expecting value"
in
str
(
ex_info
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
@
pytest
.
mark
.
parametrize
(
"unexpected_config"
,
(
"Invalid string"
,
[{
"version"
:
1
,
"loggers"
:
[]
}],
0
,
))
def
test_an_error_is_raised_when_custom_logging_config_is_unexpected_json
(
unexpected_config
:
Any
):
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however it fails before any change in behavior or
configuration occurs."""
with
NamedTemporaryFile
(
encoding
=
"utf-8"
,
mode
=
"w"
)
as
logging_config_file
:
logging_config_file
.
write
(
json
.
dumps
(
unexpected_config
))
logging_config_file
.
flush
()
with
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
logging_config_file
.
name
):
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
ValueError
assert
"Invalid logging config. Expected Dict, got"
in
str
(
ex_info
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
def
test_custom_logging_config_is_parsed_and_used_when_provided
():
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
valid_logging_config
=
{
"loggers"
:
{
"vllm.test_logger.logger"
:
{
"handlers"
:
[],
"propagate"
:
False
,
}
},
"version"
:
1
}
with
NamedTemporaryFile
(
encoding
=
"utf-8"
,
mode
=
"w"
)
as
logging_config_file
:
logging_config_file
.
write
(
json
.
dumps
(
valid_logging_config
))
logging_config_file
.
flush
()
with
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
logging_config_file
.
name
),
patch
(
"logging.config.dictConfig"
)
as
dict_config_mock
:
_configure_vllm_root_logger
()
assert
dict_config_mock
.
called_with
(
valid_logging_config
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
0
)
def
test_custom_logging_config_causes_an_error_if_configure_logging_is_off
():
"""This test calls _configure_vllm_root_logger again to test custom logging
config behavior, however mocks are used to ensure no changes in behavior or
configuration occur."""
valid_logging_config
=
{
"loggers"
:
{
"vllm.test_logger.logger"
:
{
"handlers"
:
[],
}
},
"version"
:
1
}
with
NamedTemporaryFile
(
encoding
=
"utf-8"
,
mode
=
"w"
)
as
logging_config_file
:
logging_config_file
.
write
(
json
.
dumps
(
valid_logging_config
))
logging_config_file
.
flush
()
with
patch
(
"vllm.logger.VLLM_LOGGING_CONFIG_PATH"
,
logging_config_file
.
name
):
with
pytest
.
raises
(
RuntimeError
)
as
ex_info
:
_configure_vllm_root_logger
()
assert
ex_info
.
type
is
RuntimeError
expected_message_snippet
=
(
"VLLM_CONFIGURE_LOGGING evaluated to false, but "
"VLLM_LOGGING_CONFIG_PATH was given."
)
assert
expected_message_snippet
in
str
(
ex_info
)
# Remember! The root logger is assumed to have been configured as
# though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None.
root_logger
=
logging
.
getLogger
(
"vllm"
)
other_logger_name
=
f
"vllm.test_logger.
{
uuid4
()
}
"
other_logger
=
init_logger
(
other_logger_name
)
assert
other_logger
.
handlers
!=
root_logger
.
handlers
assert
other_logger
.
level
!=
root_logger
.
level
assert
other_logger
.
propagate
tests/test_logits_processor.py
View file @
1591c68f
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
...
@@ -69,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -69,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return
logits
return
logits
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
prompt
_lens
=
[]
seq
_lens
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -80,11 +81,14 @@ def test_logits_processors(seed: int, device: str):
...
@@ -80,11 +81,14 @@ def test_logits_processors(seed: int, device: str):
logits_processors
=
[
pick_ith
]),
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
seq_lens
,
query_lens
=
seq_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
logits_processor_output
=
logits_processor
(
logits_processor_output
=
logits_processor
(
embedding
=
None
,
embedding
=
None
,
hidden_states
=
input_tensor
,
hidden_states
=
input_tensor
,
...
...
tests/tokenization/test_tokenizer.py
0 → 100644
View file @
1591c68f
import
pytest
from
transformers
import
PreTrainedTokenizerBase
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
TOKENIZER_NAMES
=
[
"facebook/opt-125m"
,
"gpt2"
,
]
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZER_NAMES
)
def
test_tokenizer_revision
(
tokenizer_name
:
str
):
# Assume that "main" branch always exists
tokenizer
=
get_tokenizer
(
tokenizer_name
,
revision
=
"main"
)
assert
isinstance
(
tokenizer
,
PreTrainedTokenizerBase
)
# Assume that "never" branch always does not exist
with
pytest
.
raises
(
OSError
,
match
=
'not a valid git identifier'
):
get_tokenizer
(
tokenizer_name
,
revision
=
"never"
)
tests/worker/test_model_runner.py
View file @
1591c68f
...
@@ -2,7 +2,10 @@ import pytest
...
@@ -2,7 +2,10 @@ import pytest
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -20,14 +23,14 @@ def test_prepare_prompt(batch_size):
...
@@ -20,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config
=
None
)
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
model_runner
.
set_block_size
(
16
)
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq
_lens
.
append
(
seq
_len
)
seq_data
=
SequenceData
(
list
(
range
(
prompt
_len
)))
seq_data
=
SequenceData
(
list
(
range
(
seq
_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -40,29 +43,29 @@ def test_prepare_prompt(batch_size):
...
@@ -40,29 +43,29 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
seq_len
-
1
)
selected_token_start_idx
+=
prompt_len
selected_token_start_idx
+=
seq_len
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
(
input_tokens
,
input_positions
,
attn_metadata
,
return_seq_lens
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_seq_lens
==
seq_lens
assert
return_prompt_lens
==
prompt_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
is_prompt
is
True
assert
attn_metadata
.
is_prompt
is
True
assert
torch
.
allclose
(
attn_metadata
.
prompt_lens_tensor
,
assert
torch
.
allclose
(
torch
.
tensor
(
prompt_lens
,
device
=
device
))
attn_metadata
.
seq_lens_tensor
,
assert
attn_metadata
.
prompt_lens
==
prompt_lens
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_prompt_len
==
max
(
prompt_lens
)
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_seq_len
==
max
(
seq_lens
)
# Test subquery start locs.
# Test subquery start locs.
start_idx
=
0
start_idx
=
0
start_loc
=
[
start_idx
]
start_loc
=
[
start_idx
]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
start_idx
+=
prompt
_len
start_idx
+=
seq
_len
start_loc
.
append
(
start_idx
)
start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
attn_metadata
.
subquery_start_loc
,
attn_metadata
.
subquery_start_loc
,
...
@@ -72,17 +75,16 @@ def test_prepare_prompt(batch_size):
...
@@ -72,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc.
# equivalent to subquery_start_loc.
start_idx
=
0
start_idx
=
0
seq_start_loc
=
[
start_idx
]
seq_start_loc
=
[
start_idx
]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
start_idx
+=
prompt
_len
start_idx
+=
seq
_len
seq_start_loc
.
append
(
start_idx
)
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
attn_metadata
.
max_context_len
is
None
assert
torch
.
allclose
(
assert
torch
.
allclose
(
attn_metadata
.
context_lens
,
attn_metadata
.
context_lens
_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens
.
shape
[
0
],
torch
.
zeros
(
attn_metadata
.
context_lens
_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
device
))
device
=
device
))
...
@@ -93,15 +95,18 @@ def test_prepare_prompt(batch_size):
...
@@ -93,15 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prerill.
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
use_cuda_graph
is
False
assert
len
(
input_tokens
)
==
sum
(
prompt
_lens
)
assert
len
(
input_tokens
)
==
sum
(
seq
_lens
)
assert
len
(
input_positions
)
==
sum
(
prompt
_lens
)
assert
len
(
input_positions
)
==
sum
(
seq
_lens
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
seq_lens
,
assert
len
(
input_tokens
)
==
sum
(
prompt_lens
)
query_lens
=
seq_lens
,
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
assert
len
(
input_tokens
)
==
sum
(
seq_lens
)
assert
len
(
input_positions
)
==
sum
(
seq_lens
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
device
=
actual
.
device
,
...
@@ -140,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -140,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config
=
None
)
lora_config
=
None
)
model_runner
.
set_block_size
(
16
)
model_runner
.
set_block_size
(
16
)
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq
_lens
.
append
(
seq
_len
)
seq_data
=
list
(
range
(
prompt
_len
))
seq_data
=
list
(
range
(
seq
_len
))
seq_data
=
SequenceData
(
seq_data
)
seq_data
=
SequenceData
(
seq_data
)
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
...
@@ -166,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -166,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
is_prompt
is
False
assert
attn_metadata
.
is_prompt
is
False
assert
attn_metadata
.
prompt_lens
is
None
assert
attn_metadata
.
seq_lens
is
None
assert
attn_metadata
.
max_prompt_len
is
None
assert
attn_metadata
.
subquery_start_loc
is
None
assert
attn_metadata
.
subquery_start_loc
is
None
assert
attn_metadata
.
seq_start_loc
is
None
assert
attn_metadata
.
seq_start_loc
is
None
assert
attn_metadata
.
max_
context
_len
==
max
(
prompt
_lens
)
assert
attn_metadata
.
max_
seq
_len
==
max
(
seq
_lens
)
assert
torch
.
allclose
(
assert
torch
.
allclose
(
attn_metadata
.
context_lens
[:
len
(
prompt
_lens
)],
attn_metadata
.
seq_lens_tensor
[:
len
(
seq
_lens
)],
torch
.
tensor
(
prompt
_lens
,
dtype
=
torch
.
int
,
device
=
device
))
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
device
))
# block table's first index corresponds to each batch, meaning in
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
# decoding it is each token.
...
@@ -192,12 +196,15 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -192,12 +196,15 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
# Verify Sampling
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
selected_token_start_idx
+=
1
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
SamplingMetadata
.
prepare
(
prompt_lens
,
seq_group_metadata_list
,
subquery_lens
=
prompt_lens
)
seq_lens
,
query_lens
=
seq_lens
,
device
=
model_runner
.
device
,
pin_memory
=
model_runner
.
pin_memory
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
device
=
actual
.
device
,
...
@@ -232,29 +239,27 @@ def test_empty_seq_group():
...
@@ -232,29 +239,27 @@ def test_empty_seq_group():
assert
attn_metadata
is
None
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
assert
len
(
slot_mapping
)
==
0
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
(
input_tokens
,
input_positions
,
attn_metadata
,
return_seq_lens
,
_
,
_
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
len
(
input_tokens
)
==
0
assert
len
(
input_tokens
)
==
0
assert
len
(
input_positions
)
==
0
assert
len
(
input_positions
)
==
0
assert
attn_metadata
is
None
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
assert
len
(
slot_mapping
)
==
0
assert
len
(
return_prompt_lens
)
==
0
assert
len
(
return_seq_lens
)
==
0
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
monkeypatch
):
def
get_world_size
(
group
=
None
):
@
pytest
.
fixture
return
1
def
distributed_init
():
init_distributed_environment
(
world_size
=
1
,
rank
=
0
,
distributed_init_method
=
f
"tcp://127.0.0.1:
{
get_open_port
()
}
"
,
local_rank
=
0
)
def
mock_get_process_group_ranks
(
group
=
None
):
return
[
0
]
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_world_size"
,
get_world_size
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
))
)
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_process_group_ranks"
,
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
mock_get_process_group_ranks
)
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
distributed_init
):
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
...
@@ -280,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
...
@@ -280,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
model_runner
.
set_block_size
(
16
)
model_runner
.
set_block_size
(
16
)
# Add prefill requests.
# Add prefill requests.
prompt
_lens
=
[]
seq
_lens
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
prefill_metadata_list
=
[]
prefill_metadata_list
=
[]
decode_metadata_list
=
[]
decode_metadata_list
=
[]
...
@@ -289,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
...
@@ -289,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
decode_batch_size
=
batch_size
-
prefill_batch_size
decode_batch_size
=
batch_size
-
prefill_batch_size
for
i
in
range
(
prefill_batch_size
):
for
i
in
range
(
prefill_batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt
_lens
.
append
(
prompt
_len
)
seq
_lens
.
append
(
seq
_len
)
seq_data
=
SequenceData
(
list
(
range
(
prompt
_len
)))
seq_data
=
SequenceData
(
list
(
range
(
seq
_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -306,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
...
@@ -306,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
# Add decode requests
# Add decode requests
for
i
in
range
(
prefill_batch_size
,
batch_size
):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq
_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
prompt
_len
))
prompt_toks
=
list
(
range
(
seq
_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
=
SequenceData
(
prompt_toks
)
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
...
@@ -335,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
...
@@ -335,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
else
:
else
:
assert
attn_metadata
.
num_decode_tokens
==
_get_graph_batch_size
(
assert
attn_metadata
.
num_decode_tokens
==
_get_graph_batch_size
(
decode_batch_size
)
decode_batch_size
)
assert
attn_metadata
.
num_prefill_tokens
==
sum
(
prompt
_lens
)
assert
attn_metadata
.
num_prefill_tokens
==
sum
(
seq
_lens
)
# Verify attn metadata is consistent. We don't need to test individual
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
# values here because they are tested above.
...
...
tests/worker/test_swap.py
View file @
1591c68f
import
torch
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -54,10 +55,14 @@ def test_swap() -> None:
...
@@ -54,10 +55,14 @@ def test_swap() -> None:
# Test swap out.
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
execute_model_req
=
ExecuteModelRequest
(
blocks_to_swap_in
=
{},
seq_group_metadata_list
=
[],
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_in
=
{},
blocks_to_copy
=
{})
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{},
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
...
@@ -66,14 +71,19 @@ def test_swap() -> None:
...
@@ -66,14 +71,19 @@ def test_swap() -> None:
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
# Test swap in.
blocks_to_swap_in
=
{
19
:
45
,
67
:
23
,
12
:
78
,
40
:
99
,
1
:
71
}
execute_model_req
.
blocks_to_swap_out
=
{}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
execute_model_req
.
blocks_to_swap_in
=
{
blocks_to_swap_in
=
blocks_to_swap_in
,
19
:
45
,
blocks_to_swap_out
=
{},
67
:
23
,
blocks_to_copy
=
{})
12
:
78
,
40
:
99
,
1
:
71
}
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_in
.
items
():
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
.
items
():
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
Prev
1
2
3
4
5
6
7
8
9
10
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment