"vscode:/vscode.git/clone" did not exist on "9324e10275cce6e0fd189bf1ebb0c399d858e9e1"
Unverified Commit 0e9164b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
Note: these tests will only pass on L4 GPU. Note: these tests will only pass on L4 GPU.
""" """
import os import os
from typing import List
import pytest import pytest
import torch import torch
...@@ -100,7 +101,7 @@ def test_models(example_prompts, model_name, kv_cache_dtype) -> None: ...@@ -100,7 +101,7 @@ def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
] ]
params = SamplingParams(max_tokens=20, temperature=0) params = SamplingParams(max_tokens=20, temperature=0)
generations = [] generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision, # Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way. # since the expected strs were generated this way.
for prompt in formatted_prompts: for prompt in formatted_prompts:
......
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`. Run `pytest tests/prefix_caching/test_prefix_caching.py`.
""" """
from typing import List
import pytest import pytest
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager_v1 import CachedBlockAllocator from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device from vllm.utils import Device
...@@ -43,7 +46,7 @@ def test_block_allocator( ...@@ -43,7 +46,7 @@ def test_block_allocator(
def test_eviction(num_blocks: int, ): def test_eviction(num_blocks: int, ):
block_size = 16 block_size = 16
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
blocks = [] blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks): for i in range(num_blocks):
# use i as the block_hash # use i as the block_hash
......
...@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`. ...@@ -4,6 +4,7 @@ Run `pytest tests/quantization/test_configs.py --forked`.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import pytest import pytest
...@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [ ...@@ -51,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) @pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: str) -> None: def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype model_path, quantization_arg, expected_type = model_arg_exptype
try: try:
......
from typing import List
import pytest import pytest
import torch import torch
...@@ -62,21 +64,22 @@ def test_get_prompt_logprobs( ...@@ -62,21 +64,22 @@ def test_get_prompt_logprobs(
for logprobs in result.outputs[0].logprobs: for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs assert len(logprobs) == num_top_logprobs
output_text = result.outputs[0].text output_text = result.outputs[0].text
output_string_from_most_likely_tokens = [] output_string_from_most_likely_tokens_lst: List[str] = []
for top_logprobs in result.outputs[0].logprobs: for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values())) top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append( output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token) top_logprob.decoded_token)
if detokenize: if detokenize:
output_string_from_most_likely_tokens = "".join( output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens) output_string_from_most_likely_tokens_lst)
assert output_text == output_string_from_most_likely_tokens, ( assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position " "The output text from the top logprob for each token position "
"should be the same as the output text in the result.") "should be the same as the output text in the result.")
else: else:
assert output_text == '' assert output_text == ''
assert output_string_from_most_likely_tokens == [None] * max_tokens assert output_string_from_most_likely_tokens_lst == ([None] *
max_tokens)
# The first prompt logprob is always None # The first prompt logprob is always None
assert result.prompt_logprobs[0] is None assert result.prompt_logprobs[0] is None
......
...@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution( ...@@ -246,8 +246,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal) draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000] sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference = [] distance_wrt_reference: List[float] = []
distance_wrt_target = [] distance_wrt_target: List[float] = []
for num_samples in sample_sizes: for num_samples in sample_sizes:
(reference_vs_rejsample_dist, (reference_vs_rejsample_dist,
......
import itertools import itertools
import random import random
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -49,8 +49,8 @@ def _do_sample( ...@@ -49,8 +49,8 @@ def _do_sample(
sampling_params: SamplingParams, sampling_params: SamplingParams,
device: str, device: str,
): ):
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -212,7 +212,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size = random.randint(1, 128) batch_size = random.randint(1, 128)
expected_penalization = [] expected_penalization = []
sequence_metadata_list = [] sequence_metadata_list: List[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts # 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2 is_prompt = random.random() < 0.2
while batch_size > 0: while batch_size > 0:
...@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -232,8 +232,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids) stop_token_ids=stop_token_ids)
seq_data = {} seq_data: Dict[int, SequenceData] = {}
seq_group_penalization = [] seq_group_penalization: List[bool] = []
for _ in range(num_seqs): for _ in range(num_seqs):
num_input = random.randint(1, 100) num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100) num_generated = 0 if is_prompt else random.randint(1, 100)
...@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -392,17 +392,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else: else:
test_cases = [generate_test_case()] test_cases = [generate_test_case()]
def run_test_case(*, def run_test_case(*, expected_penalization: List[bool],
expected_penalization=None, seq_group_metadata_list: List[SequenceGroupMetadata]):
seq_group_metadata_list=None):
assert expected_penalization, \ assert expected_penalization, \
"Invalid test case, need expected_penalization" "Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \ assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list" "Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
seq_lens = [] seq_lens: List[int] = []
sampling_params_per_row = [] sampling_params_per_row: List[SamplingParams] = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
...@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -472,15 +471,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size) input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = [] expected_tokens: List[Optional[List[int]]] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3) sampling_type = random.randint(0, 3)
if sampling_type == 0: if sampling_type == 0:
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
expected = [torch.argmax(fake_logits[i], dim=-1).item()] expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
elif sampling_type in (1, 2): elif sampling_type in (1, 2):
n = random.randint(1, 10) n = random.randint(1, 10)
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -536,15 +535,18 @@ def test_sampler_mixed(seed: int, device: str):
] ]
continue continue
expected_tokens_item = expected_tokens[i]
assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples): for n, nth_output in enumerate(sequence_output.samples):
if (metadata.sampling_params.temperature == 0 if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None): or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed # Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n] assert nth_output.output_token == expected_tokens_item[n]
else: else:
# For non-seeded random check that one of the high-logit # For non-seeded random check that one of the high-logit
# tokens were chosen # tokens were chosen
assert nth_output.output_token in expected_tokens[i] assert nth_output.output_token in expected_tokens_item
# Test batch # Test batch
test_sampling() test_sampling()
...@@ -588,8 +590,8 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -588,8 +590,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
warpers = generation_model._get_logits_warper(generation_config) warpers = generation_model._get_logits_warper(generation_config)
assert len(warpers) == 2 # top_p and top_k assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -622,6 +624,9 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -622,6 +624,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
with patch("vllm.model_executor.layers.sampler._sample", mock_sample): with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata) sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
......
...@@ -118,16 +118,17 @@ class AsyncLLM: ...@@ -118,16 +118,17 @@ class AsyncLLM:
raise ValueError("The lengths of prompts and " raise ValueError("The lengths of prompts and "
"sampling_params must be the same.") "sampling_params must be the same.")
async def get_output(prompt, sampling_param) -> str: async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid() request_id = random_uuid()
results_generator = self.llm_engine.generate( results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id) prompt, sampling_param, request_id)
final_output = None final_output = None
async for request_output in results_generator: async for request_output in results_generator:
final_output = request_output final_output = request_output
assert final_output is not None
return final_output return final_output
outputs = [] outputs: List[RequestOutput] = []
try: try:
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
...@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm): ...@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator( def get_output_from_llm_generator(
llm_generator, prompts, llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]: sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = [] tokens: List[str] = []
token_ids = [] token_ids: List[List[int]] = []
for llm in llm_generator(): for llm in llm_generator():
maybe_assert_ngram_worker(llm) maybe_assert_ngram_worker(llm)
...@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int], ...@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
nvmlInit() nvmlInit()
start_time = time.time() start_time = time.time()
while True: while True:
output = {} output: Dict[int, str] = {}
output_raw = {} output_raw: Dict[int, float] = {}
for device in devices: for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device) dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle) mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
......
from typing import List
import pytest import pytest
import torch import torch
...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int): ...@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device='cuda', device='cuda',
) )
expected_output = [ expected_output: List[List[int]] = [
[], [],
] ]
for i in range(proposal_token_ids.shape[0]): for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist()) expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [ actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
......
import random import random
from typing import Dict, List
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import torch import torch
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -210,7 +211,7 @@ def test_same_output_for_multi_step(): ...@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly. # Run single-step repeatedly.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output = [] single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts] continuations = [[1] for _ in prompts]
set_random_seed(seed) set_random_seed(seed)
...@@ -232,11 +233,15 @@ def test_same_output_for_multi_step(): ...@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token) continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison. # Get token ids and logprobs for comparison.
multi_step_output_logprobs = [[] for _ in prompts] multi_step_output_logprobs: List[List[Dict[int,
single_step_output_logprobs = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids = [[] for _ in prompts] single_step_output_logprobs: List[List[Dict[int,
single_step_output_token_ids = [[] for _ in prompts] Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output, for multi_step, single_step in zip(multi_step_output,
single_step_output): single_step_output):
......
import random import random
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -7,7 +8,7 @@ import torch ...@@ -7,7 +8,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
seen_contexts = [] seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for seq_data in seq_group_metadata.seq_data.values(): for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids()) seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = [] expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip( for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()): prompts, prev_output_tokens, proposal_token_ids.tolist()):
...@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next(iter(seq_group_metadata.seq_data.keys())) next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list for seq_group_metadata in seq_group_metadata_list
] ]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output: for step in output:
for seq_group in step: for seq_group in step:
......
from itertools import count from itertools import count
from typing import Dict, Iterable, List, Optional, Union from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch
...@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port ...@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int: def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size return (seq_len + block_size - 1) // block_size
...@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine): ...@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
value_blocks.zero_() value_blocks.zero_()
def create_worker(cls: type, def create_worker(cls: Callable[..., T],
model_name: str, model_name: str,
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
seed: int, seed: int,
is_driver_worker: bool = True, is_driver_worker: bool = True,
enforce_eager: bool = True): enforce_eager: bool = True) -> T:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
seed=seed, seed=seed,
...@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose( ...@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list( def create_sampler_output_list(
token_ids: torch.Tensor, token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]], probs: GenericSequence[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]], logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist() token_ids_by_step = token_ids.tolist()
......
...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -51,7 +51,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None, max_input_length=None,
) )
hashes = [] hashes: List[List[List[int]]] = []
for prefix in prefixes: for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids: for lora_int_id in concurrent_lora_int_ids:
......
...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration(): ...@@ -47,6 +47,7 @@ def test_default_vllm_root_logger_configuration():
assert not logger.propagate assert not logger.propagate
handler = logger.handlers[0] handler = logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
assert handler.stream == sys.stdout assert handler.stream == sys.stdout
assert handler.level == logging.INFO assert handler.level == logging.INFO
......
...@@ -153,8 +153,8 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -153,8 +153,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially. # Run sequentially.
seq = create_sequence() seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = [] sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token = [] sequential_logprobs_text_other_token: List[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids, for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs): dummy_logprobs):
seq.append_token_id(new_token, logprobs) seq.append_token_id(new_token, logprobs)
......
...@@ -79,7 +79,7 @@ class RemoteOpenAIServer: ...@@ -79,7 +79,7 @@ class RemoteOpenAIServer:
self.host = str(args.host or 'localhost') self.host = str(args.host or 'localhost')
self.port = int(args.port) self.port = int(args.port)
self._runner = self._RemoteRunner.remote( self._runner = self._RemoteRunner.remote( # type: ignore
cli_args, cli_args,
wait_url=self.url_for("health"), wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S) wait_timeout=self.MAX_SERVER_START_WAIT_S)
......
from typing import List
import pytest import pytest
import torch import torch
...@@ -35,8 +37,8 @@ def test_prepare_prompt(batch_size): ...@@ -35,8 +37,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill=False, enable_chunked_prefill=False,
) )
seq_lens = [] seq_lens: List[int] = []
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]} block_tables = {0: [1]}
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
...@@ -151,15 +153,14 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -151,15 +153,14 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False, enable_chunked_prefill=False,
) )
context_lens = [] context_lens: List[int] = []
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
# Assume each seq group finishes prefill. # Assume each seq group finishes prefill.
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1 context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len) context_lens.append(context_len)
seq_data = list(range(context_len)) seq_data = SequenceData(list(range(context_len)))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len) seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished. # Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0) seq_data.append_token_id(1, 0)
...@@ -257,7 +258,7 @@ def test_empty_seq_group(): ...@@ -257,7 +258,7 @@ def test_empty_seq_group():
dtype="float16", dtype="float16",
enforce_eager=False, enforce_eager=False,
) )
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list) model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = ( input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_tokens,
...@@ -310,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -310,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
) )
# Add prefill requests. # Add prefill requests.
seq_lens = [] seq_lens: List[int] = []
seq_group_metadata_list = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
prefill_metadata_list = [] prefill_metadata_list: List[SequenceGroupMetadata] = []
decode_metadata_list = [] decode_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]} block_tables = {0: [1]}
prefill_batch_size = batch_size // 2 prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size decode_batch_size = batch_size - prefill_batch_size
......
...@@ -245,7 +245,7 @@ def _make_alibi_bias( ...@@ -245,7 +245,7 @@ def _make_alibi_bias(
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: List[int], seq_lens: List[int],
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens: for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
...@@ -271,7 +271,7 @@ def _make_sliding_window_bias( ...@@ -271,7 +271,7 @@ def _make_sliding_window_bias(
window_size: Optional[int], window_size: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
attn_biases = [] attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens: for seq_len in seq_lens:
tensor = torch.full( tensor = torch.full(
(1, seq_len, seq_len), (1, seq_len, seq_len),
......
...@@ -431,8 +431,8 @@ def _make_alibi_bias( ...@@ -431,8 +431,8 @@ def _make_alibi_bias(
num_kv_heads: int, num_kv_heads: int,
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: List[int], seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias: ) -> List[AttentionBias]:
attn_biases = [] attn_biases: List[AttentionBias] = []
for seq_len in seq_lens: for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype) bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
......
...@@ -252,7 +252,7 @@ class BlockTable: ...@@ -252,7 +252,7 @@ class BlockTable:
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Device) -> List[Block]: device: Device) -> List[Block]:
blocks = [] blocks: List[Block] = []
for block_token_ids in chunk_list(token_ids, self._block_size): for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size: if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block. # If the block is full, create an immutable block.
......
...@@ -111,7 +111,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -111,7 +111,7 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
source_blocks = get_all_blocks_recursively(last_block) source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = [] forked_blocks: List[Block] = []
prev_block = None prev_block = None
for block in source_blocks: for block in source_blocks:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment