Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
...@@ -5,7 +5,6 @@ Run `pytest tests/quantization/test_configs.py --forked`. ...@@ -5,7 +5,6 @@ Run `pytest tests/quantization/test_configs.py --forked`.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import pytest import pytest
...@@ -53,7 +52,7 @@ MODEL_ARG_EXPTYPES = [ ...@@ -53,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) @pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None: def test_auto_gptq(model_arg_exptype: tuple[str, None, str]) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype model_path, quantization_arg, expected_type = model_arg_exptype
try: try:
......
...@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details. ...@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`. Run `pytest tests/quantization/test_register_quantization_config.py`.
""" """
from typing import Any, Dict, List, Optional from typing import Any, Optional
import pytest import pytest
import torch import torch
...@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig): ...@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Name of the quantization method.""" """Name of the quantization method."""
return "custom_quant" return "custom_quant"
def get_supported_act_dtypes(self) -> List["torch.dtype"]: def get_supported_act_dtypes(self) -> list["torch.dtype"]:
"""List of supported activation dtypes.""" """List of supported activation dtypes."""
return [torch.float16, torch.bfloat16] return [torch.float16, torch.bfloat16]
...@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig): ...@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
return -1 return -1
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory.""" """List of filenames to search for in the model directory."""
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig": def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig":
"""Create a config class from the model's quantization config.""" """Create a config class from the model's quantization config."""
return CustomQuantConfig(num_bits=config.get("num_bits", 8)) return CustomQuantConfig(num_bits=config.get("num_bits", 8))
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest import pytest
import torch import torch
...@@ -70,7 +68,7 @@ def test_get_prompt_logprobs( ...@@ -70,7 +68,7 @@ def test_get_prompt_logprobs(
assert (len(logprobs) == num_top_logprobs assert (len(logprobs) == num_top_logprobs
or len(logprobs) == num_top_logprobs + 1) or len(logprobs) == num_top_logprobs + 1)
output_text = result.outputs[0].text output_text = result.outputs[0].text
output_string_from_most_likely_tokens_lst: List[str] = [] output_string_from_most_likely_tokens_lst: list[str] = []
for top_logprobs in result.outputs[0].logprobs: for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values())) top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append( output_string_from_most_likely_tokens_lst.append(
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
Run `pytest tests/samplers/test_no_bad_words.py`. Run `pytest tests/samplers/test_no_bad_words.py`.
""" """
from typing import List, Optional from typing import Optional
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -16,8 +16,8 @@ def _generate( ...@@ -16,8 +16,8 @@ def _generate(
prompt: str, prompt: str,
num_prompt_tokens: int, num_prompt_tokens: int,
temperature: float = 0, temperature: float = 0,
bad_words: Optional[List[str]] = None, bad_words: Optional[list[str]] = None,
) -> List[int]: ) -> list[int]:
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=temperature, temperature=temperature,
bad_words=bad_words, bad_words=bad_words,
...@@ -59,7 +59,7 @@ class TestOneTokenBadWord: ...@@ -59,7 +59,7 @@ class TestOneTokenBadWord:
def _generate(self, def _generate(self,
model: LLM, model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]: bad_words: Optional[list[str]] = None) -> list[int]:
return _generate( return _generate(
model=model, model=model,
prompt=self.PROMPT, prompt=self.PROMPT,
...@@ -69,7 +69,7 @@ class TestOneTokenBadWord: ...@@ -69,7 +69,7 @@ class TestOneTokenBadWord:
def _encode(self, def _encode(self,
prompt: str, prompt: str,
add_special_tokens: bool = True) -> List[int]: add_special_tokens: bool = True) -> list[int]:
return self.tokenizer(prompt, return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids add_special_tokens=add_special_tokens).input_ids
...@@ -149,7 +149,7 @@ class TestTwoTokenBadWord: ...@@ -149,7 +149,7 @@ class TestTwoTokenBadWord:
def _generate(self, def _generate(self,
model: LLM, model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]: bad_words: Optional[list[str]] = None) -> list[int]:
return _generate( return _generate(
model=model, model=model,
prompt=self.PROMPT, prompt=self.PROMPT,
...@@ -158,7 +158,7 @@ class TestTwoTokenBadWord: ...@@ -158,7 +158,7 @@ class TestTwoTokenBadWord:
) )
@staticmethod @staticmethod
def _contains(sequence: List[int], subsequence: List[int]) -> bool: def _contains(sequence: list[int], subsequence: list[int]) -> bool:
searched = False searched = False
for start in range(len(sequence)): for start in range(len(sequence)):
...@@ -181,6 +181,6 @@ class TestTwoTokenBadWord: ...@@ -181,6 +181,6 @@ class TestTwoTokenBadWord:
def _encode(self, def _encode(self,
prompt: str, prompt: str,
add_special_tokens: bool = True) -> List[int]: add_special_tokens: bool = True) -> list[int]:
return self.tokenizer(prompt, return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids add_special_tokens=add_special_tokens).input_ids
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for rejection sampling.""" """Tests for rejection sampling."""
from typing import List, Tuple
import pytest import pytest
import torch import torch
...@@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution( ...@@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal) draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000] sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference: List[float] = [] distance_wrt_reference: list[float] = []
distance_wrt_target: List[float] = [] distance_wrt_target: list[float] = []
for num_samples in sample_sizes: for num_samples in sample_sizes:
(reference_vs_rejsample_dist, (reference_vs_rejsample_dist,
...@@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution( ...@@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution(
expected_improvement_multiplier) expected_improvement_multiplier)
def get_ratio_first_to_last(elements: List[float]) -> float: def get_ratio_first_to_last(elements: list[float]) -> float:
return elements[0] / elements[-1] return elements[0] / elements[-1]
...@@ -477,7 +476,7 @@ class _CorrectnessTestHelper: ...@@ -477,7 +476,7 @@ class _CorrectnessTestHelper:
def generate_probs_for_test( def generate_probs_for_test(
self, draft_and_target_probs_equal: bool self, draft_and_target_probs_equal: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
draft_probs, target_probs = (F.softmax( draft_probs, target_probs = (F.softmax(
torch.rand(self.vocab_size, dtype=torch.float32), torch.rand(self.vocab_size, dtype=torch.float32),
dim=-1, dim=-1,
...@@ -499,7 +498,7 @@ class _CorrectnessTestHelper: ...@@ -499,7 +498,7 @@ class _CorrectnessTestHelper:
def run_and_compare_distributions(self, draft_probs: torch.Tensor, def run_and_compare_distributions(self, draft_probs: torch.Tensor,
target_probs: torch.Tensor, target_probs: torch.Tensor,
reference_probs: torch.Tensor, reference_probs: torch.Tensor,
num_samples: int) -> Tuple[float, float]: num_samples: int) -> tuple[float, float]:
# Sample using rejection sampling. # Sample using rejection sampling.
rej_sample_probs = self._estimate_rejection_sampling_pdf( rej_sample_probs = self._estimate_rejection_sampling_pdf(
draft_probs, target_probs, num_samples) draft_probs, target_probs, num_samples)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Optional
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
...@@ -30,7 +30,7 @@ class MockLogitsSampler(Sampler): ...@@ -30,7 +30,7 @@ class MockLogitsSampler(Sampler):
def _prepare_test( def _prepare_test(
batch_size: int batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: ) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, VOCAB_SIZE), fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2, 1e-2,
...@@ -53,8 +53,8 @@ def _do_sample( ...@@ -53,8 +53,8 @@ def _do_sample(
sampling_params: SamplingParams, sampling_params: SamplingParams,
device: str, device: str,
): ):
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: List[int] = [] seq_lens: list[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens, def create_sampling_params(min_tokens,
eos_token_id=0, eos_token_id=0,
*, *,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None): prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams( sampling_params = SamplingParams(
min_tokens=min_tokens, min_tokens=min_tokens,
...@@ -196,7 +196,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -196,7 +196,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size = random.randint(1, 128) batch_size = random.randint(1, 128)
expected_penalization = [] expected_penalization = []
sequence_metadata_list: List[SequenceGroupMetadata] = [] sequence_metadata_list: list[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts # 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2 is_prompt = random.random() < 0.2
while batch_size > 0: while batch_size > 0:
...@@ -216,8 +216,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -216,8 +216,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids) stop_token_ids=stop_token_ids)
seq_data: Dict[int, SequenceData] = {} seq_data: dict[int, SequenceData] = {}
seq_group_penalization: List[bool] = [] seq_group_penalization: list[bool] = []
for _ in range(num_seqs): for _ in range(num_seqs):
num_input = random.randint(1, 100) num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100) num_generated = 0 if is_prompt else random.randint(1, 100)
...@@ -376,16 +376,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -376,16 +376,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else: else:
test_cases = [generate_test_case()] test_cases = [generate_test_case()]
def run_test_case(*, expected_penalization: List[bool], def run_test_case(*, expected_penalization: list[bool],
seq_group_metadata_list: List[SequenceGroupMetadata]): seq_group_metadata_list: list[SequenceGroupMetadata]):
assert expected_penalization, \ assert expected_penalization, \
"Invalid test case, need expected_penalization" "Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \ assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list" "Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
seq_lens: List[int] = [] seq_lens: list[int] = []
sampling_params_per_row: List[SamplingParams] = [] sampling_params_per_row: list[SamplingParams] = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
...@@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str):
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size) input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = [] expected_tokens: list[Optional[list[int]]] = []
seq_lens: List[int] = [] seq_lens: list[int] = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[list[int]] = None
sampling_type = random.randint(0, 2) sampling_type = random.randint(0, 2)
if sampling_type == 0: if sampling_type == 0:
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
...@@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str):
)) ))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
generators: Dict[str, torch.Generator] = {} generators: dict[str, torch.Generator] = {}
def test_sampling(): def test_sampling():
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
...@@ -587,8 +587,8 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -587,8 +587,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=device) device=device)
assert len(processors) == 2 # top_p and top_k assert len(processors) == 2 # top_p and top_k
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: List[int] = [] seq_lens: list[int] = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str): ...@@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
vocab_size = 8 vocab_size = 8
def test_sampling_params(sampling_params: List[SamplingParams]): def test_sampling_params(sampling_params: list[SamplingParams]):
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: List[int] = [] seq_lens: list[int] = []
for i in range(2): for i in range(2):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from itertools import cycle from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union from typing import Optional, Union
import pytest import pytest
import torch import torch
...@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm): ...@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator( def get_output_from_llm_generator(
llm_generator, prompts, llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]], float]: sampling_params) -> tuple[list[str], list[list[int]], float]:
tokens: List[str] = [] tokens: list[str] = []
token_ids: List[List[int]] = [] token_ids: list[list[int]] = []
acceptance_rate: float = -1.0 acceptance_rate: float = -1.0
for llm in llm_generator(): for llm in llm_generator():
maybe_assert_ngram_worker(llm) maybe_assert_ngram_worker(llm)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest import pytest
import torch import torch
...@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int): ...@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int):
device='cuda', device='cuda',
) )
expected_output: List[List[int]] = [ expected_output: list[list[int]] = [
[], [],
] ]
for i in range(proposal_token_ids.shape[0]): for i in range(proposal_token_ids.shape[0]):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import Dict, List
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -221,7 +220,7 @@ def test_same_output_for_multi_step(): ...@@ -221,7 +220,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly. # Run single-step repeatedly.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = [] single_step_output: list[SamplerOutput] = []
continuations = [[1] for _ in prompts] continuations = [[1] for _ in prompts]
set_random_seed(seed) set_random_seed(seed)
...@@ -243,15 +242,15 @@ def test_same_output_for_multi_step(): ...@@ -243,15 +242,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token) continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison. # Get token ids and logprobs for comparison.
multi_step_output_logprobs: List[List[Dict[int, multi_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[] Logprob]]] = [[]
for _ in prompts] for _ in prompts]
single_step_output_logprobs: List[List[Dict[int, single_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[] Logprob]]] = [[]
for _ in prompts] for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts] multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts] single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output, for multi_step, single_step in zip(multi_step_output,
single_step_output): single_step_output):
...@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output(): ...@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output():
# will simulate the bonus token case with the second token # will simulate the bonus token case with the second token
# being the bonus token. # being the bonus token.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = [] single_step_output: list[SamplerOutput] = []
set_random_seed(seed) set_random_seed(seed)
for _ in range(num_steps): for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts( seq_group_metadata_list = create_seq_group_metadata_from_prompts(
...@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output(): ...@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
# will simulate the bonus token case with the second token # will simulate the bonus token case with the second token
# being the bonus token. # being the bonus token.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = [] single_step_output: list[SamplerOutput] = []
set_random_seed(seed) set_random_seed(seed)
for _ in range(num_steps): for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts( seq_group_metadata_list = create_seq_group_metadata_from_prompts(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List
import pytest import pytest
import torch import torch
...@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker ...@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker
from .utils import create_batch, create_worker from .utils import create_batch, create_worker
def create_proposal(propose_lens: List[int], vocab_size: int, def create_proposal(propose_lens: list[int], vocab_size: int,
device: str) -> SpeculativeProposals: device: str) -> SpeculativeProposals:
batch_size = len(propose_lens) batch_size = len(propose_lens)
max_propose_len = max(propose_lens) max_propose_len = max(propose_lens)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import random import random
from collections import defaultdict from collections import defaultdict
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model( ...@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k)) num_lookahead_slots=k))
seen_contexts: List[List[int]] = [] seen_contexts: list[list[int]] = []
call_args_list = target_worker.execute_model.call_args_list call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
...@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model( ...@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model(
for seq_data in seq_group_metadata.seq_data.values(): for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids()) seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts: List[List[int]] = [] expected_seen_contexts: list[list[int]] = []
for prompt, prev_generated, draft_tokens in zip( for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()): prompts, prev_output_tokens, proposal_token_ids.tolist()):
...@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int, ...@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
next(iter(seq_group_metadata.seq_data.keys())) next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list for seq_group_metadata in seq_group_metadata_list
] ]
actual_output_by_seq: Dict[int, List[SequenceOutput]] = { actual_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: [] seq_id: []
for seq_id in seq_ids for seq_id in seq_ids
} }
expected_output_by_seq: Dict[int, List[SequenceOutput]] = { expected_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: [] seq_id: []
for seq_id in seq_ids for seq_id in seq_ids
} }
...@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens(): ...@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens():
size=(batch_size, (k + 1)), size=(batch_size, (k + 1)),
dtype=torch.int64, dtype=torch.int64,
device='cuda') device='cuda')
expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data: for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[ expected_request_id_seq_ids_mapping[
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence as GenericSequence
from itertools import count from itertools import count
from typing import Callable, Dict, List, Optional from typing import Callable, Optional, TypeVar, Union
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch
...@@ -44,7 +43,7 @@ def mock_worker(cls=None, ...@@ -44,7 +43,7 @@ def mock_worker(cls=None,
return worker return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
seed_iter = iter(rand_seeds) seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model original_execute_model = worker.execute_model
...@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]): ...@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model return new_execute_model
def zero_kv_cache(cache_engine: List[CacheEngine]): def zero_kv_cache(cache_engine: list[CacheEngine]):
assert cache_engine[0].gpu_cache assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache: for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_() key_blocks.zero_()
...@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T], ...@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
def create_seq_group_metadata_from_prompts( def create_seq_group_metadata_from_prompts(
prompts: List[List[int]], prompts: list[list[int]],
num_gpu_blocks: int, num_gpu_blocks: int,
block_size: int, block_size: int,
final_prompt_lens: List[int], final_prompt_lens: list[int],
continuations: Optional[List[List[int]]] = None, continuations: Optional[list[list[int]]] = None,
seq_ids: Optional[List[int]] = None, seq_ids: Optional[list[int]] = None,
) -> List[SequenceGroupMetadata]: ) -> list[SequenceGroupMetadata]:
if continuations is None: if continuations is None:
continuations = [[] for _ in prompts] continuations = [[] for _ in prompts]
...@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts( ...@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
def create_chunked_seq_group_metadata_from_prompt( def create_chunked_seq_group_metadata_from_prompt(
prompt: List[int], prompt: list[int],
num_gpu_blocks: int, num_gpu_blocks: int,
chunk_size: int, chunk_size: int,
block_size: int, block_size: int,
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]: seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
if seq_id is None: if seq_id is None:
seq_id = 0 seq_id = 0
...@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt( ...@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
def assert_logprobs_dict_allclose( def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, Logprob]], actual_logprobs: list[dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None: expected_logprobs: list[dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip( for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs): actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set( assert set(single_step_actual_logprobs.keys()) == set(
...@@ -202,7 +201,7 @@ def create_sampler_output_list( ...@@ -202,7 +201,7 @@ def create_sampler_output_list(
token_ids: torch.Tensor, token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]], probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]], logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
num_steps, batch_size = token_ids.shape num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist() token_ids_by_step = token_ids.tolist()
...@@ -231,9 +230,9 @@ def create_sampler_output_list( ...@@ -231,9 +230,9 @@ def create_sampler_output_list(
def create_batch(batch_size, def create_batch(batch_size,
k, k,
prompt_len: Union[int, List[int]] = 10, prompt_len: Union[int, list[int]] = 10,
prev_output_token_len: int = 10, prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None, seq_ids: Optional[list[int]] = None,
num_gpu_blocks: Optional[int] = None, num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None, block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None): prefill_chunk_size: Optional[int] = None):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
Run `pytest tests/test_cache_block_hashing.py`. Run `pytest tests/test_cache_block_hashing.py`.
""" """
from typing import List, Optional from typing import Optional
import pytest import pytest
...@@ -44,7 +44,7 @@ def flatten_2d(li): ...@@ -44,7 +44,7 @@ def flatten_2d(li):
@pytest.mark.parametrize("concurrent_lora_int_ids", @pytest.mark.parametrize("concurrent_lora_int_ids",
[[None], [1], [None, 1], [None, 1, 2], [1, 2]]) [[None], [1], [None, 1], [None, 1, 2], [1, 2]])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
concurrent_lora_int_ids: List[Optional[int]]): concurrent_lora_int_ids: list[Optional[int]]):
tokenizer = TokenizerGroup( tokenizer = TokenizerGroup(
tokenizer_id="facebook/opt-125m", tokenizer_id="facebook/opt-125m",
...@@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None, max_input_length=None,
) )
hashes: List[List[List[int]]] = [] hashes: list[list[list[int]]] = []
for prefix in prefixes: for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids: for lora_int_id in concurrent_lora_int_ids:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest import pytest
from vllm.inputs import zip_enc_dec_prompts from vllm.inputs import zip_enc_dec_prompts
...@@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str): ...@@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str):
@pytest.mark.parametrize('token_input', TOKEN_INPUTS) @pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]): def test_parse_single_batch_token_consistent(token_input: list[int]):
assert parse_and_batch_prompt(token_input) \ assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input]) == parse_and_batch_prompt([token_input])
......
...@@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( ...@@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
with pytest.raises(ValueError) as ex_info: with pytest.raises(ValueError) as ex_info:
_configure_vllm_root_logger() _configure_vllm_root_logger()
assert ex_info.type == ValueError # noqa: E721 assert ex_info.type == ValueError # noqa: E721
assert "Invalid logging config. Expected Dict, got" in str(ex_info) assert "Invalid logging config. Expected dict, got" in str(ex_info)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import Tuple
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor): ...@@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor):
def _prepare_test( def _prepare_test(
batch_size: int batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: ) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000 vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size), fake_logits = torch.full((batch_size, vocab_size),
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
import os import os
import socket import socket
from typing import AsyncIterator, Tuple from collections.abc import AsyncIterator
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -33,7 +33,7 @@ async def test_merge_async_iterators(): ...@@ -33,7 +33,7 @@ async def test_merge_async_iterators():
iterators = [mock_async_iterator(i) for i in range(3)] iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators) merged_iterator = merge_async_iterators(*iterators)
async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async def stream_output(generator: AsyncIterator[tuple[int, str]]):
async for idx, output in generator: async for idx, output in generator:
print(f"idx: {idx}, output: {output}") print(f"idx: {idx}, output: {output}")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Generator, List, Optional from collections.abc import Generator
from typing import Any, Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: ...@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@pytest.fixture(name="complete_sequence_token_ids") @pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str, def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> List[int]: tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids return complete_sequence_token_ids
...@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None): ...@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
def create_dummy_logprobs( def create_dummy_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
return [{ return [{
token_id: Logprob(logprob=0.0), token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1) token_id + 1: Logprob(logprob=0.1)
...@@ -186,10 +187,10 @@ def create_dummy_logprobs( ...@@ -186,10 +187,10 @@ def create_dummy_logprobs(
def create_dummy_prompt_logprobs( def create_dummy_prompt_logprobs(
complete_sequence_token_ids: List[int] complete_sequence_token_ids: list[int]
) -> List[Optional[Dict[int, Any]]]: ) -> list[Optional[dict[int, Any]]]:
# logprob for the first prompt token is None. # logprob for the first prompt token is None.
logprobs: List[Optional[Dict[int, Any]]] = [None] logprobs: list[Optional[dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs return logprobs
...@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs( ...@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) @pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str, def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int], complete_sequence_token_ids: list[int],
detokenizer: Detokenizer, detokenizer: Detokenizer,
skip_special_tokens: bool): skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly.""" """Verify Detokenizer decodes logprobs correctly."""
...@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially. # Run sequentially.
seq = create_sequence() seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token: List[str] = [] sequential_logprobs_text_chosen_token: list[str] = []
sequential_logprobs_text_other_token: List[str] = [] sequential_logprobs_text_other_token: list[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids, for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs): dummy_logprobs):
seq.append_token_id(new_token, logprobs) seq.append_token_id(new_token, logprobs)
...@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
detokenizer: Detokenizer): detokenizer: Detokenizer):
"""Verify Detokenizer decodes prompt logprobs correctly.""" """Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True, sampling_params = SamplingParams(skip_special_tokens=True,
...@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], ...@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
dummy_logprobs, dummy_logprobs,
position_offset=0) position_offset=0)
# First logprob is None. # First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[ decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
1:] # type: ignore 1:] # type: ignore
# decoded_prompt_logprobs doesn't contain the first token. # decoded_prompt_logprobs doesn't contain the first token.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
import os import os
import sys import sys
from typing import List, Optional from typing import Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
...@@ -129,7 +129,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -129,7 +129,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
def __init__(self, def __init__(self,
*args, *args,
fail_at: Optional[List[int]] = None, fail_at: Optional[list[int]] = None,
**kwargs): **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.i = 0 self.i = 0
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import (TokenizerBase, from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
...@@ -17,15 +17,15 @@ class TestTokenizer(TokenizerBase): ...@@ -17,15 +17,15 @@ class TestTokenizer(TokenizerBase):
return TestTokenizer() return TestTokenizer()
@property @property
def all_special_tokens_extended(self) -> List[str]: def all_special_tokens_extended(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
@property @property
def all_special_tokens(self) -> List[str]: def all_special_tokens(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
@property @property
def all_special_ids(self) -> List[int]: def all_special_ids(self) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
@property @property
...@@ -58,7 +58,7 @@ class TestTokenizer(TokenizerBase): ...@@ -58,7 +58,7 @@ class TestTokenizer(TokenizerBase):
def __call__( def __call__(
self, self,
text: Union[str, List[str], List[int]], text: Union[str, list[str], list[int]],
text_pair: Optional[str] = None, text_pair: Optional[str] = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
truncation: bool = False, truncation: bool = False,
...@@ -66,10 +66,10 @@ class TestTokenizer(TokenizerBase): ...@@ -66,10 +66,10 @@ class TestTokenizer(TokenizerBase):
): ):
raise NotImplementedError() raise NotImplementedError()
def get_vocab(self) -> Dict[str, int]: def get_vocab(self) -> dict[str, int]:
raise NotImplementedError() raise NotImplementedError()
def get_added_vocab(self) -> Dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError() raise NotImplementedError()
def encode_one( def encode_one(
...@@ -77,33 +77,33 @@ class TestTokenizer(TokenizerBase): ...@@ -77,33 +77,33 @@ class TestTokenizer(TokenizerBase):
text: str, text: str,
truncation: bool = False, truncation: bool = False,
max_length: Optional[int] = None, max_length: Optional[int] = None,
) -> List[int]: ) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
def encode(self, def encode(self,
text: str, text: str,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
def apply_chat_template(self, def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
**kwargs) -> List[int]: **kwargs) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
def convert_tokens_to_string(self, tokens: List[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError() raise NotImplementedError()
def decode(self, def decode(self,
ids: Union[List[int], int], ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str: skip_special_tokens: bool = True) -> str:
raise NotImplementedError() raise NotImplementedError()
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, self,
ids: List[int], ids: list[int],
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
) -> List[str]: ) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment