Unverified commit 0e9164b4, authored by Cyrus Leung and committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
......@@ -3,6 +3,7 @@
Note: these tests will only pass on L4 GPU.
"""
import os
from typing import List
import pytest
import torch
......@@ -100,7 +101,7 @@ def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
]
params = SamplingParams(max_tokens=20, temperature=0)
generations = []
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
......
......@@ -2,8 +2,11 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from typing import List
import pytest
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device
......@@ -43,7 +46,7 @@ def test_block_allocator(
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
blocks = []
blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
# use i as the block_hash
......
......@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`.
"""
from dataclasses import dataclass
from typing import Tuple
import pytest
......@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: str) -> None:
def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype
try:
......
from typing import List
import pytest
import torch
......@@ -62,21 +64,22 @@ def test_get_prompt_logprobs(
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text
output_string_from_most_likely_tokens = []
output_string_from_most_likely_tokens_lst: List[str] = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token)
if detokenize:
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
output_string_from_most_likely_tokens_lst)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
else:
assert output_text == ''
assert output_string_from_most_likely_tokens == [None] * max_tokens
assert output_string_from_most_likely_tokens_lst == ([None] *
max_tokens)
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
......
......@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference = []
distance_wrt_target = []
distance_wrt_reference: List[float] = []
distance_wrt_target: List[float] = []
for num_samples in sample_sizes:
(reference_vs_rejsample_dist,
......
import itertools
import random
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
import pytest
......@@ -49,8 +49,8 @@ def _do_sample(
sampling_params: SamplingParams,
device: str,
):
seq_group_metadata_list = []
seq_lens = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
......@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size = random.randint(1, 128)
expected_penalization = []
sequence_metadata_list = []
sequence_metadata_list: List[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2
while batch_size > 0:
......@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)
seq_data = {}
seq_group_penalization = []
seq_data: Dict[int, SequenceData] = {}
seq_group_penalization: List[bool] = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100)
......@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else:
test_cases = [generate_test_case()]
def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
def run_test_case(*, expected_penalization: List[bool],
seq_group_metadata_list: List[SequenceGroupMetadata]):
assert expected_penalization, \
"Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
seq_lens = []
sampling_params_per_row = []
seq_lens: List[int] = []
sampling_params_per_row: List[SamplingParams] = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
......@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = []
seq_lens = []
seq_lens: List[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
......@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str):
]
continue
expected_tokens_item = expected_tokens[i]
assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples):
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
assert nth_output.output_token == expected_tokens_item[n]
else:
# For non-seeded random, check that one of the high-logit
# tokens was chosen
assert nth_output.output_token in expected_tokens[i]
assert nth_output.output_token in expected_tokens_item
# Test batch
test_sampling()
......@@ -588,8 +590,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
warpers = generation_model._get_logits_warper(generation_config)
assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = []
seq_lens = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
......@@ -622,6 +624,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
......
......@@ -118,16 +118,17 @@ class AsyncLLM:
raise ValueError("The lengths of prompts and "
"sampling_params must be the same.")
async def get_output(prompt, sampling_param) -> str:
async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid()
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
assert final_output is not None
return final_output
outputs = []
outputs: List[RequestOutput] = []
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
......@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
tokens: List[str] = []
token_ids: List[List[int]] = []
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
......@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
nvmlInit()
start_time = time.time()
while True:
output = {}
output_raw = {}
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
......
from typing import List
import pytest
import torch
......@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device='cuda',
)
expected_output = [
expected_output: List[List[int]] = [
[],
]
for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
......
import random
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
......@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output = []
single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
......@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs = [[] for _ in prompts]
single_step_output_logprobs = [[] for _ in prompts]
multi_step_output_token_ids = [[] for _ in prompts]
single_step_output_token_ids = [[] for _ in prompts]
multi_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
......
import random
from types import SimpleNamespace
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
......@@ -7,7 +8,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
......@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts = []
seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
......@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = []
expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
......@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output:
for seq_group in step:
......
from itertools import count
from typing import Dict, Iterable, List, Optional, Union
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock
import torch
......@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` cover ``seq_len``.

    Despite the name, the result is a block *count*, not a length rounded
    up to a multiple of ``block_size``. For a non-negative ``seq_len`` and
    a positive ``block_size`` this equals ``ceil(seq_len / block_size)``,
    computed with integer arithmetic only.

    Args:
        seq_len: Length to be covered.
        block_size: Capacity of a single block.

    Returns:
        The number of blocks needed.

    Raises:
        ZeroDivisionError: If ``block_size`` is 0.
    """
    # Adding (block_size - 1) before the floor division turns it into a
    # ceiling division without going through floats.
    return (seq_len + block_size - 1) // block_size
......@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
value_blocks.zero_()
def create_worker(cls: type,
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True):
enforce_eager: bool = True) -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
......@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]],
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
......
......@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None,
)
hashes = []
hashes: List[List[List[int]]] = []
for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids:
......
......@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert not logger.propagate
handler = logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
assert handler.stream == sys.stdout
assert handler.level == logging.INFO
......
......@@ -153,8 +153,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = []
sequential_logprobs_text_other_token = []
sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token: List[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
......
......@@ -79,7 +79,7 @@ class RemoteOpenAIServer:
self.host = str(args.host or 'localhost')
self.port = int(args.port)
self._runner = self._RemoteRunner.remote(
self._runner = self._RemoteRunner.remote( # type: ignore
cli_args,
wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S)
......
from typing import List
import pytest
import torch
......@@ -35,8 +37,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill=False,
)
seq_lens = []
seq_group_metadata_list = []
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
......@@ -151,15 +153,14 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
context_lens = []
seq_group_metadata_list = []
context_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data = SequenceData(list(range(context_len)))
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
......@@ -257,7 +258,7 @@ def test_empty_seq_group():
dtype="float16",
enforce_eager=False,
)
seq_group_metadata_list = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
......@@ -310,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
# Add prefill requests.
seq_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
prefill_metadata_list: List[SequenceGroupMetadata] = []
decode_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
......
......@@ -245,7 +245,7 @@ def _make_alibi_bias(
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
......@@ -271,7 +271,7 @@ def _make_sliding_window_bias(
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
......
......@@ -431,8 +431,8 @@ def _make_alibi_bias(
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
) -> List[AttentionBias]:
attn_biases: List[AttentionBias] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
......
......@@ -252,7 +252,7 @@ class BlockTable:
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
blocks = []
blocks: List[Block] = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
......
......@@ -111,7 +111,7 @@ class NaiveBlockAllocator(BlockAllocator):
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
forked_blocks: List[Block] = []
prev_block = None
for block in source_blocks:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment