Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
114 additions
and
125 deletions
+114
-125
tests/quantization/test_configs.py
tests/quantization/test_configs.py
+1
-2
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+4
-4
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+1
-3
tests/samplers/test_no_bad_words.py
tests/samplers/test_no_bad_words.py
+8
-8
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+5
-6
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+22
-22
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+5
-4
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+1
-3
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+7
-8
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+1
-2
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+5
-6
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+16
-17
tests/test_cache_block_hashing.py
tests/test_cache_block_hashing.py
+3
-3
tests/test_inputs.py
tests/test_inputs.py
+1
-3
tests/test_logger.py
tests/test_logger.py
+1
-1
tests/test_logits_processor.py
tests/test_logits_processor.py
+1
-2
tests/test_utils.py
tests/test_utils.py
+2
-2
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+12
-11
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+2
-2
tests/tokenization/test_tokenizer_registry.py
tests/tokenization/test_tokenizer_registry.py
+16
-16
No files found.
tests/quantization/test_configs.py
View file @
cf069aa8
...
@@ -5,7 +5,6 @@ Run `pytest tests/quantization/test_configs.py --forked`.
...
@@ -5,7 +5,6 @@ Run `pytest tests/quantization/test_configs.py --forked`.
"""
"""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Tuple
import
pytest
import
pytest
...
@@ -53,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
...
@@ -53,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@
pytest
.
mark
.
parametrize
(
"model_arg_exptype"
,
MODEL_ARG_EXPTYPES
)
@
pytest
.
mark
.
parametrize
(
"model_arg_exptype"
,
MODEL_ARG_EXPTYPES
)
def
test_auto_gptq
(
model_arg_exptype
:
T
uple
[
str
,
None
,
str
])
->
None
:
def
test_auto_gptq
(
model_arg_exptype
:
t
uple
[
str
,
None
,
str
])
->
None
:
model_path
,
quantization_arg
,
expected_type
=
model_arg_exptype
model_path
,
quantization_arg
,
expected_type
=
model_arg_exptype
try
:
try
:
...
...
tests/quantization/test_register_quantization_config.py
View file @
cf069aa8
...
@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
...
@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
"""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
...
@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Name of the quantization method."""
"""Name of the quantization method."""
return
"custom_quant"
return
"custom_quant"
def
get_supported_act_dtypes
(
self
)
->
L
ist
[
"torch.dtype"
]:
def
get_supported_act_dtypes
(
self
)
->
l
ist
[
"torch.dtype"
]:
"""List of supported activation dtypes."""
"""List of supported activation dtypes."""
return
[
torch
.
float16
,
torch
.
bfloat16
]
return
[
torch
.
float16
,
torch
.
bfloat16
]
...
@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
...
@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
return
-
1
return
-
1
@
staticmethod
@
staticmethod
def
get_config_filenames
()
->
L
ist
[
str
]:
def
get_config_filenames
()
->
l
ist
[
str
]:
"""List of filenames to search for in the model directory."""
"""List of filenames to search for in the model directory."""
return
[]
return
[]
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
D
ict
[
str
,
Any
])
->
"CustomQuantConfig"
:
def
from_config
(
cls
,
config
:
d
ict
[
str
,
Any
])
->
"CustomQuantConfig"
:
"""Create a config class from the model's quantization config."""
"""Create a config class from the model's quantization config."""
return
CustomQuantConfig
(
num_bits
=
config
.
get
(
"num_bits"
,
8
))
return
CustomQuantConfig
(
num_bits
=
config
.
get
(
"num_bits"
,
8
))
...
...
tests/samplers/test_logprobs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -70,7 +68,7 @@ def test_get_prompt_logprobs(
...
@@ -70,7 +68,7 @@ def test_get_prompt_logprobs(
assert
(
len
(
logprobs
)
==
num_top_logprobs
assert
(
len
(
logprobs
)
==
num_top_logprobs
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
or
len
(
logprobs
)
==
num_top_logprobs
+
1
)
output_text
=
result
.
outputs
[
0
].
text
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
L
ist
[
str
]
=
[]
output_string_from_most_likely_tokens_lst
:
l
ist
[
str
]
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens_lst
.
append
(
output_string_from_most_likely_tokens_lst
.
append
(
...
...
tests/samplers/test_no_bad_words.py
View file @
cf069aa8
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
Run `pytest tests/samplers/test_no_bad_words.py`.
Run `pytest tests/samplers/test_no_bad_words.py`.
"""
"""
from
typing
import
List
,
Optional
from
typing
import
Optional
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -16,8 +16,8 @@ def _generate(
...
@@ -16,8 +16,8 @@ def _generate(
prompt
:
str
,
prompt
:
str
,
num_prompt_tokens
:
int
,
num_prompt_tokens
:
int
,
temperature
:
float
=
0
,
temperature
:
float
=
0
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
,
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
temperature
,
temperature
=
temperature
,
bad_words
=
bad_words
,
bad_words
=
bad_words
,
...
@@ -59,7 +59,7 @@ class TestOneTokenBadWord:
...
@@ -59,7 +59,7 @@ class TestOneTokenBadWord:
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
model
:
LLM
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
)
->
L
ist
[
int
]:
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
)
->
l
ist
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
model
=
model
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
...
@@ -69,7 +69,7 @@ class TestOneTokenBadWord:
...
@@ -69,7 +69,7 @@ class TestOneTokenBadWord:
def
_encode
(
self
,
def
_encode
(
self
,
prompt
:
str
,
prompt
:
str
,
add_special_tokens
:
bool
=
True
)
->
L
ist
[
int
]:
add_special_tokens
:
bool
=
True
)
->
l
ist
[
int
]:
return
self
.
tokenizer
(
prompt
,
return
self
.
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
).
input_ids
add_special_tokens
=
add_special_tokens
).
input_ids
...
@@ -149,7 +149,7 @@ class TestTwoTokenBadWord:
...
@@ -149,7 +149,7 @@ class TestTwoTokenBadWord:
def
_generate
(
self
,
def
_generate
(
self
,
model
:
LLM
,
model
:
LLM
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
)
->
L
ist
[
int
]:
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
)
->
l
ist
[
int
]:
return
_generate
(
return
_generate
(
model
=
model
,
model
=
model
,
prompt
=
self
.
PROMPT
,
prompt
=
self
.
PROMPT
,
...
@@ -158,7 +158,7 @@ class TestTwoTokenBadWord:
...
@@ -158,7 +158,7 @@ class TestTwoTokenBadWord:
)
)
@
staticmethod
@
staticmethod
def
_contains
(
sequence
:
L
ist
[
int
],
subsequence
:
L
ist
[
int
])
->
bool
:
def
_contains
(
sequence
:
l
ist
[
int
],
subsequence
:
l
ist
[
int
])
->
bool
:
searched
=
False
searched
=
False
for
start
in
range
(
len
(
sequence
)):
for
start
in
range
(
len
(
sequence
)):
...
@@ -181,6 +181,6 @@ class TestTwoTokenBadWord:
...
@@ -181,6 +181,6 @@ class TestTwoTokenBadWord:
def
_encode
(
self
,
def
_encode
(
self
,
prompt
:
str
,
prompt
:
str
,
add_special_tokens
:
bool
=
True
)
->
L
ist
[
int
]:
add_special_tokens
:
bool
=
True
)
->
l
ist
[
int
]:
return
self
.
tokenizer
(
prompt
,
return
self
.
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
).
input_ids
add_special_tokens
=
add_special_tokens
).
input_ids
tests/samplers/test_rejection_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Tests for rejection sampling."""
"""Tests for rejection sampling."""
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal
)
draft_and_target_probs_equal
)
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
distance_wrt_reference
:
L
ist
[
float
]
=
[]
distance_wrt_reference
:
l
ist
[
float
]
=
[]
distance_wrt_target
:
L
ist
[
float
]
=
[]
distance_wrt_target
:
l
ist
[
float
]
=
[]
for
num_samples
in
sample_sizes
:
for
num_samples
in
sample_sizes
:
(
reference_vs_rejsample_dist
,
(
reference_vs_rejsample_dist
,
...
@@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution(
expected_improvement_multiplier
)
expected_improvement_multiplier
)
def
get_ratio_first_to_last
(
elements
:
L
ist
[
float
])
->
float
:
def
get_ratio_first_to_last
(
elements
:
l
ist
[
float
])
->
float
:
return
elements
[
0
]
/
elements
[
-
1
]
return
elements
[
0
]
/
elements
[
-
1
]
...
@@ -477,7 +476,7 @@ class _CorrectnessTestHelper:
...
@@ -477,7 +476,7 @@ class _CorrectnessTestHelper:
def
generate_probs_for_test
(
def
generate_probs_for_test
(
self
,
draft_and_target_probs_equal
:
bool
self
,
draft_and_target_probs_equal
:
bool
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
draft_probs
,
target_probs
=
(
F
.
softmax
(
draft_probs
,
target_probs
=
(
F
.
softmax
(
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
dim
=-
1
,
dim
=-
1
,
...
@@ -499,7 +498,7 @@ class _CorrectnessTestHelper:
...
@@ -499,7 +498,7 @@ class _CorrectnessTestHelper:
def
run_and_compare_distributions
(
self
,
draft_probs
:
torch
.
Tensor
,
def
run_and_compare_distributions
(
self
,
draft_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
reference_probs
:
torch
.
Tensor
,
reference_probs
:
torch
.
Tensor
,
num_samples
:
int
)
->
T
uple
[
float
,
float
]:
num_samples
:
int
)
->
t
uple
[
float
,
float
]:
# Sample using rejection sampling.
# Sample using rejection sampling.
rej_sample_probs
=
self
.
_estimate_rejection_sampling_pdf
(
rej_sample_probs
=
self
.
_estimate_rejection_sampling_pdf
(
draft_probs
,
target_probs
,
num_samples
)
draft_probs
,
target_probs
,
num_samples
)
...
...
tests/samplers/test_sampler.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
itertools
import
itertools
import
random
import
random
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
import
pytest
...
@@ -30,7 +30,7 @@ class MockLogitsSampler(Sampler):
...
@@ -30,7 +30,7 @@ class MockLogitsSampler(Sampler):
def
_prepare_test
(
def
_prepare_test
(
batch_size
:
int
batch_size
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsSampler
]:
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
fake_logits
=
torch
.
full
((
batch_size
,
VOCAB_SIZE
),
1e-2
,
1e-2
,
...
@@ -53,8 +53,8 @@ def _do_sample(
...
@@ -53,8 +53,8 @@ def _do_sample(
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
device
:
str
,
device
:
str
,
):
):
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def
create_sampling_params
(
min_tokens
,
def
create_sampling_params
(
min_tokens
,
eos_token_id
=
0
,
eos_token_id
=
0
,
*
,
*
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
):
prompt_logprobs
:
Optional
[
int
]
=
None
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
...
@@ -196,7 +196,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -196,7 +196,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
128
)
batch_size
=
random
.
randint
(
1
,
128
)
expected_penalization
=
[]
expected_penalization
=
[]
sequence_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
sequence_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
# 20% chance to generate seq group metadata list with all prompts
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
while
batch_size
>
0
:
...
@@ -216,8 +216,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -216,8 +216,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
)
seq_data
:
D
ict
[
int
,
SequenceData
]
=
{}
seq_data
:
d
ict
[
int
,
SequenceData
]
=
{}
seq_group_penalization
:
L
ist
[
bool
]
=
[]
seq_group_penalization
:
l
ist
[
bool
]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
...
@@ -376,16 +376,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -376,16 +376,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else
:
else
:
test_cases
=
[
generate_test_case
()]
test_cases
=
[
generate_test_case
()]
def
run_test_case
(
*
,
expected_penalization
:
L
ist
[
bool
],
def
run_test_case
(
*
,
expected_penalization
:
l
ist
[
bool
],
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]):
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]):
assert
expected_penalization
,
\
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
batch_size
=
0
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
sampling_params_per_row
:
L
ist
[
SamplingParams
]
=
[]
sampling_params_per_row
:
l
ist
[
SamplingParams
]
=
[]
for
sgm
in
seq_group_metadata_list
:
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
sampling_params
=
sgm
.
sampling_params
...
@@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
expected_tokens
:
L
ist
[
Optional
[
L
ist
[
int
]]]
=
[]
expected_tokens
:
l
ist
[
Optional
[
l
ist
[
int
]]]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected
:
Optional
[
L
ist
[
int
]]
=
None
expected
:
Optional
[
l
ist
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
2
)
sampling_type
=
random
.
randint
(
0
,
2
)
if
sampling_type
==
0
:
if
sampling_type
==
0
:
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
...
@@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str):
))
))
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
seq_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
generators
:
D
ict
[
str
,
torch
.
Generator
]
=
{}
generators
:
d
ict
[
str
,
torch
.
Generator
]
=
{}
def
test_sampling
():
def
test_sampling
():
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
...
@@ -587,8 +587,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -587,8 +587,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device
=
device
)
device
=
device
)
assert
len
(
processors
)
==
2
# top_p and top_k
assert
len
(
processors
)
==
2
# top_p and top_k
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
...
@@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
vocab_size
=
8
vocab_size
=
8
def
test_sampling_params
(
sampling_params
:
L
ist
[
SamplingParams
]):
def
test_sampling_params
(
sampling_params
:
l
ist
[
SamplingParams
]):
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
L
ist
[
int
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
...
tests/spec_decode/e2e/conftest.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
itertools
import
cycle
from
itertools
import
cycle
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
pytest
import
torch
import
torch
...
@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
...
@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
def
get_output_from_llm_generator
(
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
llm_generator
,
prompts
,
sampling_params
)
->
T
uple
[
L
ist
[
str
],
L
ist
[
L
ist
[
int
]],
float
]:
sampling_params
)
->
t
uple
[
l
ist
[
str
],
l
ist
[
l
ist
[
int
]],
float
]:
tokens
:
L
ist
[
str
]
=
[]
tokens
:
l
ist
[
str
]
=
[]
token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
acceptance_rate
:
float
=
-
1.0
acceptance_rate
:
float
=
-
1.0
for
llm
in
llm_generator
():
for
llm
in
llm_generator
():
maybe_assert_ngram_worker
(
llm
)
maybe_assert_ngram_worker
(
llm
)
...
...
tests/spec_decode/test_batch_expansion.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int):
...
@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int):
device
=
'cuda'
,
device
=
'cuda'
,
)
)
expected_output
:
L
ist
[
L
ist
[
int
]]
=
[
expected_output
:
l
ist
[
l
ist
[
int
]]
=
[
[],
[],
]
]
for
i
in
range
(
proposal_token_ids
.
shape
[
0
]):
for
i
in
range
(
proposal_token_ids
.
shape
[
0
]):
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
Dict
,
List
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
...
@@ -221,7 +220,7 @@ def test_same_output_for_multi_step():
...
@@ -221,7 +220,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly.
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
L
ist
[
SamplerOutput
]
=
[]
single_step_output
:
l
ist
[
SamplerOutput
]
=
[]
continuations
=
[[
1
]
for
_
in
prompts
]
continuations
=
[[
1
]
for
_
in
prompts
]
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -243,15 +242,15 @@ def test_same_output_for_multi_step():
...
@@ -243,15 +242,15 @@ def test_same_output_for_multi_step():
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Get token ids and logprobs for comparison.
# Get token ids and logprobs for comparison.
multi_step_output_logprobs
:
L
ist
[
L
ist
[
D
ict
[
int
,
multi_step_output_logprobs
:
l
ist
[
l
ist
[
d
ict
[
int
,
Logprob
]]]
=
[[]
Logprob
]]]
=
[[]
for
_
in
prompts
]
for
_
in
prompts
]
single_step_output_logprobs
:
L
ist
[
L
ist
[
D
ict
[
int
,
single_step_output_logprobs
:
l
ist
[
l
ist
[
d
ict
[
int
,
Logprob
]]]
=
[[]
Logprob
]]]
=
[[]
for
_
in
prompts
]
for
_
in
prompts
]
multi_step_output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[[]
for
_
in
prompts
]
single_step_output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[[]
for
_
in
prompts
]
single_step_output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[[]
for
_
in
prompts
]
for
i
,
_
in
enumerate
(
prompts
):
for
i
,
_
in
enumerate
(
prompts
):
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
single_step_output
):
single_step_output
):
...
@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output():
...
@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output():
# will simulate the bonus token case with the second token
# will simulate the bonus token case with the second token
# being the bonus token.
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
L
ist
[
SamplerOutput
]
=
[]
single_step_output
:
l
ist
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
...
@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
...
@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
# will simulate the bonus token case with the second token
# will simulate the bonus token case with the second token
# being the bonus token.
# being the bonus token.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
:
L
ist
[
SamplerOutput
]
=
[]
single_step_output
:
l
ist
[
SamplerOutput
]
=
[]
set_random_seed
(
seed
)
set_random_seed
(
seed
)
for
_
in
range
(
num_steps
):
for
_
in
range
(
num_steps
):
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
...
...
tests/spec_decode/test_scorer.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker
...
@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker
from
.utils
import
create_batch
,
create_worker
from
.utils
import
create_batch
,
create_worker
def
create_proposal
(
propose_lens
:
L
ist
[
int
],
vocab_size
:
int
,
def
create_proposal
(
propose_lens
:
l
ist
[
int
],
vocab_size
:
int
,
device
:
str
)
->
SpeculativeProposals
:
device
:
str
)
->
SpeculativeProposals
:
batch_size
=
len
(
propose_lens
)
batch_size
=
len
(
propose_lens
)
max_propose_len
=
max
(
propose_lens
)
max_propose_len
=
max
(
propose_lens
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
cf069aa8
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
random
import
random
from
collections
import
defaultdict
from
collections
import
defaultdict
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Set
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
...
@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model(
...
@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
num_lookahead_slots
=
k
))
seen_contexts
:
L
ist
[
L
ist
[
int
]]
=
[]
seen_contexts
:
l
ist
[
l
ist
[
int
]]
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
...
@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model(
...
@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model(
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
expected_seen_contexts
:
L
ist
[
L
ist
[
int
]]
=
[]
expected_seen_contexts
:
l
ist
[
l
ist
[
int
]]
=
[]
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
...
@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
...
@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
]
actual_output_by_seq
:
D
ict
[
int
,
L
ist
[
SequenceOutput
]]
=
{
actual_output_by_seq
:
d
ict
[
int
,
l
ist
[
SequenceOutput
]]
=
{
seq_id
:
[]
seq_id
:
[]
for
seq_id
in
seq_ids
for
seq_id
in
seq_ids
}
}
expected_output_by_seq
:
D
ict
[
int
,
L
ist
[
SequenceOutput
]]
=
{
expected_output_by_seq
:
d
ict
[
int
,
l
ist
[
SequenceOutput
]]
=
{
seq_id
:
[]
seq_id
:
[]
for
seq_id
in
seq_ids
for
seq_id
in
seq_ids
}
}
...
@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens():
...
@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens():
size
=
(
batch_size
,
(
k
+
1
)),
size
=
(
batch_size
,
(
k
+
1
)),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
'cuda'
)
device
=
'cuda'
)
expected_request_id_seq_ids_mapping
:
D
ict
[
str
,
S
et
[
int
]]
=
defaultdict
(
set
)
expected_request_id_seq_ids_mapping
:
d
ict
[
str
,
s
et
[
int
]]
=
defaultdict
(
set
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_id
in
seq_group_metadata
.
seq_data
:
for
seq_id
in
seq_group_metadata
.
seq_data
:
expected_request_id_seq_ids_mapping
[
expected_request_id_seq_ids_mapping
[
...
...
tests/spec_decode/utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
as
GenericSequence
from
itertools
import
count
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypeVar
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
torch
import
torch
...
@@ -44,7 +43,7 @@ def mock_worker(cls=None,
...
@@ -44,7 +43,7 @@ def mock_worker(cls=None,
return
worker
return
worker
def
patch_execute_model_with_seeds
(
worker
:
Worker
,
rand_seeds
:
L
ist
[
int
]):
def
patch_execute_model_with_seeds
(
worker
:
Worker
,
rand_seeds
:
l
ist
[
int
]):
seed_iter
=
iter
(
rand_seeds
)
seed_iter
=
iter
(
rand_seeds
)
original_execute_model
=
worker
.
execute_model
original_execute_model
=
worker
.
execute_model
...
@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
...
@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return
new_execute_model
return
new_execute_model
def
zero_kv_cache
(
cache_engine
:
L
ist
[
CacheEngine
]):
def
zero_kv_cache
(
cache_engine
:
l
ist
[
CacheEngine
]):
assert
cache_engine
[
0
].
gpu_cache
assert
cache_engine
[
0
].
gpu_cache
for
key_blocks
,
value_blocks
in
cache_engine
[
0
].
gpu_cache
:
for
key_blocks
,
value_blocks
in
cache_engine
[
0
].
gpu_cache
:
key_blocks
.
zero_
()
key_blocks
.
zero_
()
...
@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
...
@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
def
create_seq_group_metadata_from_prompts
(
def
create_seq_group_metadata_from_prompts
(
prompts
:
L
ist
[
L
ist
[
int
]],
prompts
:
l
ist
[
l
ist
[
int
]],
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
final_prompt_lens
:
L
ist
[
int
],
final_prompt_lens
:
l
ist
[
int
],
continuations
:
Optional
[
L
ist
[
L
ist
[
int
]]]
=
None
,
continuations
:
Optional
[
l
ist
[
l
ist
[
int
]]]
=
None
,
seq_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
seq_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
)
->
L
ist
[
SequenceGroupMetadata
]:
)
->
l
ist
[
SequenceGroupMetadata
]:
if
continuations
is
None
:
if
continuations
is
None
:
continuations
=
[[]
for
_
in
prompts
]
continuations
=
[[]
for
_
in
prompts
]
...
@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
...
@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
def
create_chunked_seq_group_metadata_from_prompt
(
def
create_chunked_seq_group_metadata_from_prompt
(
prompt
:
L
ist
[
int
],
prompt
:
l
ist
[
int
],
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
chunk_size
:
int
,
chunk_size
:
int
,
block_size
:
int
,
block_size
:
int
,
seq_id
:
Optional
[
int
]
=
None
)
->
L
ist
[
SequenceGroupMetadata
]:
seq_id
:
Optional
[
int
]
=
None
)
->
l
ist
[
SequenceGroupMetadata
]:
if
seq_id
is
None
:
if
seq_id
is
None
:
seq_id
=
0
seq_id
=
0
...
@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
...
@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
def
assert_logprobs_dict_allclose
(
def
assert_logprobs_dict_allclose
(
actual_logprobs
:
L
ist
[
D
ict
[
int
,
Logprob
]],
actual_logprobs
:
l
ist
[
d
ict
[
int
,
Logprob
]],
expected_logprobs
:
L
ist
[
D
ict
[
int
,
Logprob
]])
->
None
:
expected_logprobs
:
l
ist
[
d
ict
[
int
,
Logprob
]])
->
None
:
for
single_step_actual_logprobs
,
single_step_expected_logprobs
in
zip
(
for
single_step_actual_logprobs
,
single_step_expected_logprobs
in
zip
(
actual_logprobs
,
expected_logprobs
):
actual_logprobs
,
expected_logprobs
):
assert
set
(
single_step_actual_logprobs
.
keys
())
==
set
(
assert
set
(
single_step_actual_logprobs
.
keys
())
==
set
(
...
@@ -202,7 +201,7 @@ def create_sampler_output_list(
...
@@ -202,7 +201,7 @@ def create_sampler_output_list(
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
probs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
probs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
logprobs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
logprobs
:
GenericSequence
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
L
ist
[
int
]]
=
None
)
->
L
ist
[
SamplerOutput
]:
seq_ids
:
Optional
[
l
ist
[
int
]]
=
None
)
->
l
ist
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
token_ids_by_step
=
token_ids
.
tolist
()
...
@@ -231,9 +230,9 @@ def create_sampler_output_list(
...
@@ -231,9 +230,9 @@ def create_sampler_output_list(
def
create_batch
(
batch_size
,
def
create_batch
(
batch_size
,
k
,
k
,
prompt_len
:
Union
[
int
,
L
ist
[
int
]]
=
10
,
prompt_len
:
Union
[
int
,
l
ist
[
int
]]
=
10
,
prev_output_token_len
:
int
=
10
,
prev_output_token_len
:
int
=
10
,
seq_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
seq_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
num_gpu_blocks
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
,
prefill_chunk_size
:
Optional
[
int
]
=
None
):
prefill_chunk_size
:
Optional
[
int
]
=
None
):
...
...
tests/test_cache_block_hashing.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
Run `pytest tests/test_cache_block_hashing.py`.
Run `pytest tests/test_cache_block_hashing.py`.
"""
"""
from
typing
import
List
,
Optional
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -44,7 +44,7 @@ def flatten_2d(li):
...
@@ -44,7 +44,7 @@ def flatten_2d(li):
@
pytest
.
mark
.
parametrize
(
"concurrent_lora_int_ids"
,
@
pytest
.
mark
.
parametrize
(
"concurrent_lora_int_ids"
,
[[
None
],
[
1
],
[
None
,
1
],
[
None
,
1
,
2
],
[
1
,
2
]])
[[
None
],
[
1
],
[
None
,
1
],
[
None
,
1
,
2
],
[
1
,
2
]])
def
test_auto_prefix_caching
(
model
:
str
,
block_size
:
int
,
max_num_seqs
:
int
,
def
test_auto_prefix_caching
(
model
:
str
,
block_size
:
int
,
max_num_seqs
:
int
,
concurrent_lora_int_ids
:
L
ist
[
Optional
[
int
]]):
concurrent_lora_int_ids
:
l
ist
[
Optional
[
int
]]):
tokenizer
=
TokenizerGroup
(
tokenizer
=
TokenizerGroup
(
tokenizer_id
=
"facebook/opt-125m"
,
tokenizer_id
=
"facebook/opt-125m"
,
...
@@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
...
@@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length
=
None
,
max_input_length
=
None
,
)
)
hashes
:
L
ist
[
L
ist
[
L
ist
[
int
]]]
=
[]
hashes
:
l
ist
[
l
ist
[
l
ist
[
int
]]]
=
[]
for
prefix
in
prefixes
:
for
prefix
in
prefixes
:
for
lora_int_id
in
concurrent_lora_int_ids
:
for
lora_int_id
in
concurrent_lora_int_ids
:
...
...
tests/test_inputs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
pytest
from
vllm.inputs
import
zip_enc_dec_prompts
from
vllm.inputs
import
zip_enc_dec_prompts
...
@@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str):
...
@@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str):
@
pytest
.
mark
.
parametrize
(
'token_input'
,
TOKEN_INPUTS
)
@
pytest
.
mark
.
parametrize
(
'token_input'
,
TOKEN_INPUTS
)
def
test_parse_single_batch_token_consistent
(
token_input
:
L
ist
[
int
]):
def
test_parse_single_batch_token_consistent
(
token_input
:
l
ist
[
int
]):
assert
parse_and_batch_prompt
(
token_input
)
\
assert
parse_and_batch_prompt
(
token_input
)
\
==
parse_and_batch_prompt
([
token_input
])
==
parse_and_batch_prompt
([
token_input
])
...
...
tests/test_logger.py
View file @
cf069aa8
...
@@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
...
@@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
_configure_vllm_root_logger
()
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
ValueError
# noqa: E721
assert
ex_info
.
type
==
ValueError
# noqa: E721
assert
"Invalid logging config. Expected
D
ict, got"
in
str
(
ex_info
)
assert
"Invalid logging config. Expected
d
ict, got"
in
str
(
ex_info
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
@
patch
(
"vllm.logger.VLLM_CONFIGURE_LOGGING"
,
1
)
...
...
tests/test_logits_processor.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor):
...
@@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor):
def
_prepare_test
(
def
_prepare_test
(
batch_size
:
int
batch_size
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsProcessor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsProcessor
]:
vocab_size
=
32000
vocab_size
=
32000
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
...
...
tests/test_utils.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
asyncio
import
asyncio
import
os
import
os
import
socket
import
socket
from
typing
import
AsyncIterator
,
Tuple
from
collections.abc
import
AsyncIterator
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -33,7 +33,7 @@ async def test_merge_async_iterators():
...
@@ -33,7 +33,7 @@ async def test_merge_async_iterators():
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
merged_iterator
=
merge_async_iterators
(
*
iterators
)
merged_iterator
=
merge_async_iterators
(
*
iterators
)
async
def
stream_output
(
generator
:
AsyncIterator
[
T
uple
[
int
,
str
]]):
async
def
stream_output
(
generator
:
AsyncIterator
[
t
uple
[
int
,
str
]]):
async
for
idx
,
output
in
generator
:
async
for
idx
,
output
in
generator
:
print
(
f
"idx:
{
idx
}
, output:
{
output
}
"
)
print
(
f
"idx:
{
idx
}
, output:
{
output
}
"
)
...
...
tests/tokenization/test_detokenize.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
from
collections.abc
import
Generator
from
typing
import
Any
,
Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
...
@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
@
pytest
.
fixture
(
name
=
"complete_sequence_token_ids"
)
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
def
create_complete_sequence_token_ids
(
complete_sequence
:
str
,
tokenizer
)
->
L
ist
[
int
]:
tokenizer
)
->
l
ist
[
int
]:
complete_sequence_token_ids
=
tokenizer
(
complete_sequence
).
input_ids
complete_sequence_token_ids
=
tokenizer
(
complete_sequence
).
input_ids
return
complete_sequence_token_ids
return
complete_sequence_token_ids
...
@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
...
@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
def
create_dummy_logprobs
(
def
create_dummy_logprobs
(
complete_sequence_token_ids
:
L
ist
[
int
])
->
L
ist
[
D
ict
[
int
,
Logprob
]]:
complete_sequence_token_ids
:
l
ist
[
int
])
->
l
ist
[
d
ict
[
int
,
Logprob
]]:
return
[{
return
[{
token_id
:
Logprob
(
logprob
=
0.0
),
token_id
:
Logprob
(
logprob
=
0.0
),
token_id
+
1
:
Logprob
(
logprob
=
0.1
)
token_id
+
1
:
Logprob
(
logprob
=
0.1
)
...
@@ -186,10 +187,10 @@ def create_dummy_logprobs(
...
@@ -186,10 +187,10 @@ def create_dummy_logprobs(
def
create_dummy_prompt_logprobs
(
def
create_dummy_prompt_logprobs
(
complete_sequence_token_ids
:
L
ist
[
int
]
complete_sequence_token_ids
:
l
ist
[
int
]
)
->
L
ist
[
Optional
[
D
ict
[
int
,
Any
]]]:
)
->
l
ist
[
Optional
[
d
ict
[
int
,
Any
]]]:
# logprob for the first prompt token is None.
# logprob for the first prompt token is None.
logprobs
:
L
ist
[
Optional
[
D
ict
[
int
,
Any
]]]
=
[
None
]
logprobs
:
l
ist
[
Optional
[
d
ict
[
int
,
Any
]]]
=
[
None
]
logprobs
.
extend
(
create_dummy_logprobs
(
complete_sequence_token_ids
)[
1
:])
logprobs
.
extend
(
create_dummy_logprobs
(
complete_sequence_token_ids
)[
1
:])
return
logprobs
return
logprobs
...
@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
...
@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
,
False
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"skip_special_tokens"
,
[
True
,
False
],
indirect
=
True
)
def
test_decode_sequence_logprobs
(
complete_sequence
:
str
,
def
test_decode_sequence_logprobs
(
complete_sequence
:
str
,
complete_sequence_token_ids
:
L
ist
[
int
],
complete_sequence_token_ids
:
l
ist
[
int
],
detokenizer
:
Detokenizer
,
detokenizer
:
Detokenizer
,
skip_special_tokens
:
bool
):
skip_special_tokens
:
bool
):
"""Verify Detokenizer decodes logprobs correctly."""
"""Verify Detokenizer decodes logprobs correctly."""
...
@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
# Run sequentially.
seq
=
create_sequence
()
seq
=
create_sequence
()
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
sequential_logprobs_text_chosen_token
:
L
ist
[
str
]
=
[]
sequential_logprobs_text_chosen_token
:
l
ist
[
str
]
=
[]
sequential_logprobs_text_other_token
:
L
ist
[
str
]
=
[]
sequential_logprobs_text_other_token
:
l
ist
[
str
]
=
[]
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
dummy_logprobs
):
dummy_logprobs
):
seq
.
append_token_id
(
new_token
,
logprobs
)
seq
.
append_token_id
(
new_token
,
logprobs
)
...
@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"complete_sequence"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
TOKENIZERS
)
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
L
ist
[
int
],
def
test_decode_prompt_logprobs
(
complete_sequence_token_ids
:
l
ist
[
int
],
detokenizer
:
Detokenizer
):
detokenizer
:
Detokenizer
):
"""Verify Detokenizer decodes prompt logprobs correctly."""
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
True
,
...
@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
...
@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
dummy_logprobs
,
dummy_logprobs
,
position_offset
=
0
)
position_offset
=
0
)
# First logprob is None.
# First logprob is None.
decoded_prompt_logprobs
:
L
ist
[
D
ict
[
int
,
Any
]]
=
dummy_logprobs
[
decoded_prompt_logprobs
:
l
ist
[
d
ict
[
int
,
Any
]]
=
dummy_logprobs
[
1
:]
# type: ignore
1
:]
# type: ignore
# decoded_prompt_logprobs doesn't contain the first token.
# decoded_prompt_logprobs doesn't contain the first token.
...
...
tests/tokenization/test_tokenizer_group.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
asyncio
import
asyncio
import
os
import
os
import
sys
import
sys
from
typing
import
List
,
Optional
from
typing
import
Optional
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -129,7 +129,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
...
@@ -129,7 +129,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
def
__init__
(
self
,
def
__init__
(
self
,
*
args
,
*
args
,
fail_at
:
Optional
[
L
ist
[
int
]]
=
None
,
fail_at
:
Optional
[
l
ist
[
int
]]
=
None
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
i
=
0
self
.
i
=
0
...
...
tests/tokenization/test_tokenizer_registry.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
...
@@ -17,15 +17,15 @@ class TestTokenizer(TokenizerBase):
...
@@ -17,15 +17,15 @@ class TestTokenizer(TokenizerBase):
return
TestTokenizer
()
return
TestTokenizer
()
@
property
@
property
def
all_special_tokens_extended
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens_extended
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
def
all_special_tokens
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
def
all_special_ids
(
self
)
->
L
ist
[
int
]:
def
all_special_ids
(
self
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
...
@@ -58,7 +58,7 @@ class TestTokenizer(TokenizerBase):
...
@@ -58,7 +58,7 @@ class TestTokenizer(TokenizerBase):
def
__call__
(
def
__call__
(
self
,
self
,
text
:
Union
[
str
,
L
ist
[
str
],
L
ist
[
int
]],
text
:
Union
[
str
,
l
ist
[
str
],
l
ist
[
int
]],
text_pair
:
Optional
[
str
]
=
None
,
text_pair
:
Optional
[
str
]
=
None
,
add_special_tokens
:
bool
=
False
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
...
@@ -66,10 +66,10 @@ class TestTokenizer(TokenizerBase):
...
@@ -66,10 +66,10 @@ class TestTokenizer(TokenizerBase):
):
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get_added_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_added_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
encode_one
(
def
encode_one
(
...
@@ -77,33 +77,33 @@ class TestTokenizer(TokenizerBase):
...
@@ -77,33 +77,33 @@ class TestTokenizer(TokenizerBase):
text
:
str
,
text
:
str
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
encode
(
self
,
def
encode
(
self
,
text
:
str
,
text
:
str
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
L
ist
[
int
]:
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
apply_chat_template
(
self
,
def
apply_chat_template
(
self
,
messages
:
L
ist
[
"ChatCompletionMessageParam"
],
messages
:
l
ist
[
"ChatCompletionMessageParam"
],
tools
:
Optional
[
L
ist
[
D
ict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
l
ist
[
d
ict
[
str
,
Any
]]]
=
None
,
**
kwargs
)
->
L
ist
[
int
]:
**
kwargs
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
convert_tokens_to_string
(
self
,
tokens
:
L
ist
[
str
])
->
str
:
def
convert_tokens_to_string
(
self
,
tokens
:
l
ist
[
str
])
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
decode
(
self
,
def
decode
(
self
,
ids
:
Union
[
L
ist
[
int
],
int
],
ids
:
Union
[
l
ist
[
int
],
int
],
skip_special_tokens
:
bool
=
True
)
->
str
:
skip_special_tokens
:
bool
=
True
)
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
convert_ids_to_tokens
(
def
convert_ids_to_tokens
(
self
,
self
,
ids
:
L
ist
[
int
],
ids
:
l
ist
[
int
],
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
)
->
L
ist
[
str
]:
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment