Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0e9164b4
"vscode:/vscode.git/clone" did not exist on "9324e10275cce6e0fd189bf1ebb0c399d858e9e1"
Unverified
Commit
0e9164b4
authored
Jun 15, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 15, 2024
Browse files
[mypy] Enable type checking for test directory (#5017)
parent
1b8a0d71
Changes
92
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
107 additions
and
73 deletions
+107
-73
tests/models/test_fp8.py
tests/models/test_fp8.py
+2
-1
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+4
-1
tests/quantization/test_configs.py
tests/quantization/test_configs.py
+2
-1
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+7
-4
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+2
-2
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+23
-18
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+7
-6
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+4
-2
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+12
-7
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+12
-5
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+9
-5
tests/test_cache_block_hashing.py
tests/test_cache_block_hashing.py
+1
-1
tests/test_logger.py
tests/test_logger.py
+1
-0
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+2
-2
tests/utils.py
tests/utils.py
+1
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+12
-11
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+2
-2
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+2
-2
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+1
-1
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+1
-1
No files found.
tests/models/test_fp8.py
View file @
0e9164b4
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
Note: these tests will only pass on L4 GPU.
Note: these tests will only pass on L4 GPU.
"""
"""
import
os
import
os
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -100,7 +101,7 @@ def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
...
@@ -100,7 +101,7 @@ def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
]
]
params
=
SamplingParams
(
max_tokens
=
20
,
temperature
=
0
)
params
=
SamplingParams
(
max_tokens
=
20
,
temperature
=
0
)
generations
=
[]
generations
:
List
[
str
]
=
[]
# Note: these need to be run 1 at a time due to numerical precision,
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
# since the expected strs were generated this way.
for
prompt
in
formatted_prompts
:
for
prompt
in
formatted_prompts
:
...
...
tests/prefix_caching/test_prefix_caching.py
View file @
0e9164b4
...
@@ -2,8 +2,11 @@
...
@@ -2,8 +2,11 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
"""
from
typing
import
List
import
pytest
import
pytest
from
vllm.block
import
PhysicalTokenBlock
from
vllm.core.block_manager_v1
import
CachedBlockAllocator
from
vllm.core.block_manager_v1
import
CachedBlockAllocator
from
vllm.utils
import
Device
from
vllm.utils
import
Device
...
@@ -43,7 +46,7 @@ def test_block_allocator(
...
@@ -43,7 +46,7 @@ def test_block_allocator(
def
test_eviction
(
num_blocks
:
int
,
):
def
test_eviction
(
num_blocks
:
int
,
):
block_size
=
16
block_size
=
16
block_allocator
=
CachedBlockAllocator
(
Device
.
CPU
,
block_size
,
num_blocks
)
block_allocator
=
CachedBlockAllocator
(
Device
.
CPU
,
block_size
,
num_blocks
)
blocks
=
[]
blocks
:
List
[
PhysicalTokenBlock
]
=
[]
for
i
in
range
(
num_blocks
):
for
i
in
range
(
num_blocks
):
# use i as the block_hash
# use i as the block_hash
...
...
tests/quantization/test_configs.py
View file @
0e9164b4
...
@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`.
...
@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`.
"""
"""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Tuple
import
pytest
import
pytest
...
@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
...
@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@
pytest
.
mark
.
parametrize
(
"model_arg_exptype"
,
MODEL_ARG_EXPTYPES
)
@
pytest
.
mark
.
parametrize
(
"model_arg_exptype"
,
MODEL_ARG_EXPTYPES
)
def
test_auto_gptq
(
model_arg_exptype
:
str
)
->
None
:
def
test_auto_gptq
(
model_arg_exptype
:
Tuple
[
str
,
None
,
str
]
)
->
None
:
model_path
,
quantization_arg
,
expected_type
=
model_arg_exptype
model_path
,
quantization_arg
,
expected_type
=
model_arg_exptype
try
:
try
:
...
...
tests/samplers/test_logprobs.py
View file @
0e9164b4
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -62,21 +64,22 @@ def test_get_prompt_logprobs(
...
@@ -62,21 +64,22 @@ def test_get_prompt_logprobs(
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
assert
len
(
logprobs
)
==
num_top_logprobs
assert
len
(
logprobs
)
==
num_top_logprobs
output_text
=
result
.
outputs
[
0
].
text
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens
=
[]
output_string_from_most_likely_tokens
_lst
:
List
[
str
]
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens
.
append
(
output_string_from_most_likely_tokens
_lst
.
append
(
top_logprob
.
decoded_token
)
top_logprob
.
decoded_token
)
if
detokenize
:
if
detokenize
:
output_string_from_most_likely_tokens
=
""
.
join
(
output_string_from_most_likely_tokens
=
""
.
join
(
output_string_from_most_likely_tokens
)
output_string_from_most_likely_tokens
_lst
)
assert
output_text
==
output_string_from_most_likely_tokens
,
(
assert
output_text
==
output_string_from_most_likely_tokens
,
(
"The output text from the top logprob for each token position "
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
"should be the same as the output text in the result."
)
else
:
else
:
assert
output_text
==
''
assert
output_text
==
''
assert
output_string_from_most_likely_tokens
==
[
None
]
*
max_tokens
assert
output_string_from_most_likely_tokens_lst
==
([
None
]
*
max_tokens
)
# The first prompt logprob is always None
# The first prompt logprob is always None
assert
result
.
prompt_logprobs
[
0
]
is
None
assert
result
.
prompt_logprobs
[
0
]
is
None
...
...
tests/samplers/test_rejection_sampler.py
View file @
0e9164b4
...
@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal
)
draft_and_target_probs_equal
)
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
sample_sizes
=
[
10
,
100
,
1_000
,
10_000
,
100_000
]
distance_wrt_reference
=
[]
distance_wrt_reference
:
List
[
float
]
=
[]
distance_wrt_target
=
[]
distance_wrt_target
:
List
[
float
]
=
[]
for
num_samples
in
sample_sizes
:
for
num_samples
in
sample_sizes
:
(
reference_vs_rejsample_dist
,
(
reference_vs_rejsample_dist
,
...
...
tests/samplers/test_sampler.py
View file @
0e9164b4
import
itertools
import
itertools
import
random
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -49,8 +49,8 @@ def _do_sample(
...
@@ -49,8 +49,8 @@ def _do_sample(
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
device
:
str
,
device
:
str
,
):
):
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
128
)
batch_size
=
random
.
randint
(
1
,
128
)
expected_penalization
=
[]
expected_penalization
=
[]
sequence_metadata_list
=
[]
sequence_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
# 20% chance to generate seq group metadata list with all prompts
# 20% chance to generate seq group metadata list with all prompts
is_prompt
=
random
.
random
()
<
0.2
is_prompt
=
random
.
random
()
<
0.2
while
batch_size
>
0
:
while
batch_size
>
0
:
...
@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
)
seq_data
=
{}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
seq_group_penalization
=
[]
seq_group_penalization
:
List
[
bool
]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
num_input
=
random
.
randint
(
1
,
100
)
num_input
=
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
num_generated
=
0
if
is_prompt
else
random
.
randint
(
1
,
100
)
...
@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else
:
else
:
test_cases
=
[
generate_test_case
()]
test_cases
=
[
generate_test_case
()]
def
run_test_case
(
*
,
def
run_test_case
(
*
,
expected_penalization
:
List
[
bool
],
expected_penalization
=
None
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]):
seq_group_metadata_list
=
None
):
assert
expected_penalization
,
\
assert
expected_penalization
,
\
"Invalid test case, need expected_penalization"
"Invalid test case, need expected_penalization"
assert
seq_group_metadata_list
,
\
assert
seq_group_metadata_list
,
\
"Invalid test case, need seq_group_metadata_list"
"Invalid test case, need seq_group_metadata_list"
batch_size
=
0
batch_size
=
0
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
sampling_params_per_row
=
[]
sampling_params_per_row
:
List
[
SamplingParams
]
=
[]
for
sgm
in
seq_group_metadata_list
:
for
sgm
in
seq_group_metadata_list
:
sampling_params
=
sgm
.
sampling_params
sampling_params
=
sgm
.
sampling_params
...
@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
input_tensor
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
expected_tokens
:
List
[
Optional
[
List
[
int
]]]
=
[]
expected_tokens
:
List
[
Optional
[
List
[
int
]]]
=
[]
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected
:
Optional
[
List
[
int
]]
=
None
expected
:
Optional
[
List
[
int
]]
=
None
sampling_type
=
random
.
randint
(
0
,
3
)
sampling_type
=
random
.
randint
(
0
,
3
)
if
sampling_type
==
0
:
if
sampling_type
==
0
:
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
expected
=
[
torch
.
argmax
(
fake_logits
[
i
],
dim
=-
1
).
item
()]
expected
=
[
int
(
torch
.
argmax
(
fake_logits
[
i
],
dim
=-
1
).
item
()
)
]
elif
sampling_type
in
(
1
,
2
):
elif
sampling_type
in
(
1
,
2
):
n
=
random
.
randint
(
1
,
10
)
n
=
random
.
randint
(
1
,
10
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
...
@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str):
]
]
continue
continue
expected_tokens_item
=
expected_tokens
[
i
]
assert
expected_tokens_item
is
not
None
for
n
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
for
n
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
if
(
metadata
.
sampling_params
.
temperature
==
0
if
(
metadata
.
sampling_params
.
temperature
==
0
or
metadata
.
sampling_params
.
seed
is
not
None
):
or
metadata
.
sampling_params
.
seed
is
not
None
):
# Ensure exact matches for greedy or random with seed
# Ensure exact matches for greedy or random with seed
assert
nth_output
.
output_token
==
expected_tokens
[
i
]
[
n
]
assert
nth_output
.
output_token
==
expected_tokens
_item
[
n
]
else
:
else
:
# For non-seeded random check that one of the high-logit
# For non-seeded random check that one of the high-logit
# tokens were chosen
# tokens were chosen
assert
nth_output
.
output_token
in
expected_tokens
[
i
]
assert
nth_output
.
output_token
in
expected_tokens
_item
# Test batch
# Test batch
test_sampling
()
test_sampling
()
...
@@ -588,8 +590,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -588,8 +590,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
warpers
=
generation_model
.
_get_logits_warper
(
generation_config
)
warpers
=
generation_model
.
_get_logits_warper
(
generation_config
)
assert
len
(
warpers
)
==
2
# top_p and top_k
assert
len
(
warpers
)
==
2
# top_p and top_k
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
...
@@ -622,6 +624,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -622,6 +624,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
assert
sample_probs
is
not
None
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
...
...
tests/spec_decode/e2e/conftest.py
View file @
0e9164b4
...
@@ -118,16 +118,17 @@ class AsyncLLM:
...
@@ -118,16 +118,17 @@ class AsyncLLM:
raise
ValueError
(
"The lengths of prompts and "
raise
ValueError
(
"The lengths of prompts and "
"sampling_params must be the same."
)
"sampling_params must be the same."
)
async
def
get_output
(
prompt
,
sampling_param
)
->
str
:
async
def
get_output
(
prompt
,
sampling_param
)
->
RequestOutput
:
request_id
=
random_uuid
()
request_id
=
random_uuid
()
results_generator
=
self
.
llm_engine
.
generate
(
results_generator
=
self
.
llm_engine
.
generate
(
prompt
,
sampling_param
,
request_id
)
prompt
,
sampling_param
,
request_id
)
final_output
=
None
final_output
=
None
async
for
request_output
in
results_generator
:
async
for
request_output
in
results_generator
:
final_output
=
request_output
final_output
=
request_output
assert
final_output
is
not
None
return
final_output
return
final_output
outputs
=
[]
outputs
:
List
[
RequestOutput
]
=
[]
try
:
try
:
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
...
@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
...
@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
def
get_output_from_llm_generator
(
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
llm_generator
,
prompts
,
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
tokens
=
[]
tokens
:
List
[
str
]
=
[]
token_ids
=
[]
token_ids
:
List
[
List
[
int
]]
=
[]
for
llm
in
llm_generator
():
for
llm
in
llm_generator
():
maybe_assert_ngram_worker
(
llm
)
maybe_assert_ngram_worker
(
llm
)
...
@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
...
@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
nvmlInit
()
nvmlInit
()
start_time
=
time
.
time
()
start_time
=
time
.
time
()
while
True
:
while
True
:
output
=
{}
output
:
Dict
[
int
,
str
]
=
{}
output_raw
=
{}
output_raw
:
Dict
[
int
,
float
]
=
{}
for
device
in
devices
:
for
device
in
devices
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
...
...
tests/spec_decode/test_batch_expansion.py
View file @
0e9164b4
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
...
@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device
=
'cuda'
,
device
=
'cuda'
,
)
)
expected_output
=
[
expected_output
:
List
[
List
[
int
]]
=
[
[],
[],
]
]
for
i
in
range
(
proposal_token_ids
.
shape
[
0
]):
for
i
in
range
(
proposal_token_ids
.
shape
[
0
]):
expected_output
.
append
(
proposal_token_ids
[:
i
+
1
].
tolist
())
expected_output
.
append
(
proposal_token_ids
[:
i
+
1
].
tolist
())
scorer
=
BatchExpansionTop1Scorer
(
mock_worker
(),
'cuda:0'
,
32_000
)
scorer
=
BatchExpansionTop1Scorer
(
mock_worker
(),
'cuda:0'
,
32_000
)
actual_output
=
scorer
.
_get_token_ids_to_score
(
proposal_token_ids
)
# pylint: disable=protected-access
actual_output
=
scorer
.
_get_token_ids_to_score
(
proposal_token_ids
.
tolist
()
)
# pylint: disable=protected-access
actual_output
=
[
actual_output
=
[
x
.
tolist
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
for
x
in
actual_output
x
.
tolist
()
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
for
x
in
actual_output
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
0e9164b4
import
random
import
random
from
typing
import
Dict
,
List
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
Logprob
,
SamplerOutput
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
...
@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly.
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
zero_kv_cache
(
worker
.
cache_engine
)
single_step_output
=
[]
single_step_output
:
List
[
SamplerOutput
]
=
[]
continuations
=
[[
1
]
for
_
in
prompts
]
continuations
=
[[
1
]
for
_
in
prompts
]
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
...
@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
continuations
[
i
].
append
(
seq_group_output
.
samples
[
0
].
output_token
)
# Get token ids and logprobs for comparison.
# Get token ids and logprobs for comparison.
multi_step_output_logprobs
=
[[]
for
_
in
prompts
]
multi_step_output_logprobs
:
List
[
List
[
Dict
[
int
,
single_step_output_logprobs
=
[[]
for
_
in
prompts
]
Logprob
]]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
=
[[]
for
_
in
prompts
]
single_step_output_logprobs
:
List
[
List
[
Dict
[
int
,
single_step_output_token_ids
=
[[]
for
_
in
prompts
]
Logprob
]]]
=
[[]
for
_
in
prompts
]
multi_step_output_token_ids
:
List
[
List
[
int
]]
=
[[]
for
_
in
prompts
]
single_step_output_token_ids
:
List
[
List
[
int
]]
=
[[]
for
_
in
prompts
]
for
i
,
_
in
enumerate
(
prompts
):
for
i
,
_
in
enumerate
(
prompts
):
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
for
multi_step
,
single_step
in
zip
(
multi_step_output
,
single_step_output
):
single_step_output
):
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
0e9164b4
import
random
import
random
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
...
@@ -7,7 +8,7 @@ import torch
...
@@ -7,7 +8,7 @@ import torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
,
SequenceOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
SpecDecodeWorkerMetrics
)
...
@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
))
num_lookahead_slots
=
k
))
seen_contexts
=
[]
seen_contexts
:
List
[
List
[
int
]]
=
[]
call_args_list
=
target_worker
.
execute_model
.
call_args_list
call_args_list
=
target_worker
.
execute_model
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
...
@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
...
@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
seen_contexts
.
append
(
seq_data
.
get_token_ids
())
expected_seen_contexts
=
[]
expected_seen_contexts
:
List
[
List
[
int
]]
=
[]
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
for
prompt
,
prev_generated
,
draft_tokens
in
zip
(
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
prompts
,
prev_output_tokens
,
proposal_token_ids
.
tolist
()):
...
@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
...
@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
next
(
iter
(
seq_group_metadata
.
seq_data
.
keys
()))
for
seq_group_metadata
in
seq_group_metadata_list
for
seq_group_metadata
in
seq_group_metadata_list
]
]
actual_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
actual_output_by_seq
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
expected_output_by_seq
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
seq_id
:
[]
for
seq_id
in
seq_ids
}
expected_output_by_seq
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
seq_id
:
[]
for
seq_id
in
seq_ids
}
for
step
in
output
:
for
step
in
output
:
for
seq_group
in
step
:
for
seq_group
in
step
:
...
...
tests/spec_decode/utils.py
View file @
0e9164b4
from
itertools
import
count
from
itertools
import
count
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypeVar
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
torch
import
torch
...
@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
...
@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
T
=
TypeVar
(
"T"
,
bound
=
Worker
)
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
def
round_up_to_next_block
(
seq_len
:
int
,
block_size
:
int
)
->
int
:
return
(
seq_len
+
block_size
-
1
)
//
block_size
return
(
seq_len
+
block_size
-
1
)
//
block_size
...
@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
...
@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
value_blocks
.
zero_
()
value_blocks
.
zero_
()
def
create_worker
(
cls
:
type
,
def
create_worker
(
cls
:
Callable
[...,
T
]
,
model_name
:
str
,
model_name
:
str
,
block_size
:
int
,
block_size
:
int
,
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
seed
:
int
,
seed
:
int
,
is_driver_worker
:
bool
=
True
,
is_driver_worker
:
bool
=
True
,
enforce_eager
:
bool
=
True
):
enforce_eager
:
bool
=
True
)
->
T
:
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model_name
,
model
=
model_name
,
seed
=
seed
,
seed
=
seed
,
...
@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
...
@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
def
create_sampler_output_list
(
def
create_sampler_output_list
(
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
probs
:
Iterabl
e
[
Optional
[
torch
.
Tensor
]],
probs
:
GenericSequenc
e
[
Optional
[
torch
.
Tensor
]],
logprobs
:
Iterabl
e
[
Optional
[
torch
.
Tensor
]],
logprobs
:
GenericSequenc
e
[
Optional
[
torch
.
Tensor
]],
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
seq_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
SamplerOutput
]:
num_steps
,
batch_size
=
token_ids
.
shape
num_steps
,
batch_size
=
token_ids
.
shape
token_ids_by_step
=
token_ids
.
tolist
()
token_ids_by_step
=
token_ids
.
tolist
()
...
...
tests/test_cache_block_hashing.py
View file @
0e9164b4
...
@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
...
@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length
=
None
,
max_input_length
=
None
,
)
)
hashes
=
[]
hashes
:
List
[
List
[
List
[
int
]]]
=
[]
for
prefix
in
prefixes
:
for
prefix
in
prefixes
:
for
lora_int_id
in
concurrent_lora_int_ids
:
for
lora_int_id
in
concurrent_lora_int_ids
:
...
...
tests/test_logger.py
View file @
0e9164b4
...
@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
...
@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert
not
logger
.
propagate
assert
not
logger
.
propagate
handler
=
logger
.
handlers
[
0
]
handler
=
logger
.
handlers
[
0
]
assert
isinstance
(
handler
,
logging
.
StreamHandler
)
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
level
==
logging
.
INFO
assert
handler
.
level
==
logging
.
INFO
...
...
tests/tokenization/test_detokenize.py
View file @
0e9164b4
...
@@ -153,8 +153,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
...
@@ -153,8 +153,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
# Run sequentially.
seq
=
create_sequence
()
seq
=
create_sequence
()
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
dummy_logprobs
=
create_dummy_logprobs
(
complete_sequence_token_ids
)
sequential_logprobs_text_chosen_token
=
[]
sequential_logprobs_text_chosen_token
:
List
[
str
]
=
[]
sequential_logprobs_text_other_token
=
[]
sequential_logprobs_text_other_token
:
List
[
str
]
=
[]
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
for
new_token
,
logprobs
in
zip
(
complete_sequence_token_ids
,
dummy_logprobs
):
dummy_logprobs
):
seq
.
append_token_id
(
new_token
,
logprobs
)
seq
.
append_token_id
(
new_token
,
logprobs
)
...
...
tests/utils.py
View file @
0e9164b4
...
@@ -79,7 +79,7 @@ class RemoteOpenAIServer:
...
@@ -79,7 +79,7 @@ class RemoteOpenAIServer:
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
port
=
int
(
args
.
port
)
self
.
port
=
int
(
args
.
port
)
self
.
_runner
=
self
.
_RemoteRunner
.
remote
(
self
.
_runner
=
self
.
_RemoteRunner
.
remote
(
# type: ignore
cli_args
,
cli_args
,
wait_url
=
self
.
url_for
(
"health"
),
wait_url
=
self
.
url_for
(
"health"
),
wait_timeout
=
self
.
MAX_SERVER_START_WAIT_S
)
wait_timeout
=
self
.
MAX_SERVER_START_WAIT_S
)
...
...
tests/worker/test_model_runner.py
View file @
0e9164b4
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -35,8 +37,8 @@ def test_prepare_prompt(batch_size):
...
@@ -35,8 +37,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
)
)
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
...
@@ -151,15 +153,14 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -151,15 +153,14 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
)
)
context_lens
=
[]
context_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
# Assume each seq group finishes prefill.
# Assume each seq group finishes prefill.
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
list
(
range
(
context_len
))
seq_data
=
SequenceData
(
list
(
range
(
context_len
)))
seq_data
=
SequenceData
(
seq_data
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -257,7 +258,7 @@ def test_empty_seq_group():
...
@@ -257,7 +258,7 @@ def test_empty_seq_group():
dtype
=
"float16"
,
dtype
=
"float16"
,
enforce_eager
=
False
,
enforce_eager
=
False
,
)
)
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
model_input
=
model_runner
.
_prepare_model_input
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
=
(
input_tokens
,
input_positions
,
attn_metadata
,
slot_mapping
=
(
model_input
.
input_tokens
,
model_input
.
input_tokens
,
...
@@ -310,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -310,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
)
# Add prefill requests.
# Add prefill requests.
seq_lens
=
[]
seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
prefill_metadata_list
=
[]
prefill_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
decode_metadata_list
=
[]
decode_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
]}
prefill_batch_size
=
batch_size
//
2
prefill_batch_size
=
batch_size
//
2
decode_batch_size
=
batch_size
-
prefill_batch_size
decode_batch_size
=
batch_size
-
prefill_batch_size
...
...
vllm/attention/backends/torch_sdpa.py
View file @
0e9164b4
...
@@ -245,7 +245,7 @@ def _make_alibi_bias(
...
@@ -245,7 +245,7 @@ def _make_alibi_bias(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
...
@@ -271,7 +271,7 @@ def _make_sliding_window_bias(
...
@@ -271,7 +271,7 @@ def _make_sliding_window_bias(
window_size
:
Optional
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
(
1
,
seq_len
,
seq_len
),
...
...
vllm/attention/backends/xformers.py
View file @
0e9164b4
...
@@ -431,8 +431,8 @@ def _make_alibi_bias(
...
@@ -431,8 +431,8 @@ def _make_alibi_bias(
num_kv_heads
:
int
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
)
->
L
owerTriangularMaskWithTensor
Bias
:
)
->
L
ist
[
Attention
Bias
]
:
attn_biases
=
[]
attn_biases
:
List
[
AttentionBias
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
...
...
vllm/core/block/block_table.py
View file @
0e9164b4
...
@@ -252,7 +252,7 @@ class BlockTable:
...
@@ -252,7 +252,7 @@ class BlockTable:
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
device
:
Device
)
->
List
[
Block
]:
blocks
=
[]
blocks
:
List
[
Block
]
=
[]
for
block_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
for
block_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
if
len
(
block_token_ids
)
==
self
.
_block_size
:
if
len
(
block_token_ids
)
==
self
.
_block_size
:
# If the block is full, create an immutable block.
# If the block is full, create an immutable block.
...
...
vllm/core/block/naive_block.py
View file @
0e9164b4
...
@@ -111,7 +111,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -111,7 +111,7 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
source_blocks
=
get_all_blocks_recursively
(
last_block
)
source_blocks
=
get_all_blocks_recursively
(
last_block
)
forked_blocks
=
[]
forked_blocks
:
List
[
Block
]
=
[]
prev_block
=
None
prev_block
=
None
for
block
in
source_blocks
:
for
block
in
source_blocks
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment