Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
......@@ -5,7 +5,6 @@ Run `pytest tests/quantization/test_configs.py --forked`.
"""
from dataclasses import dataclass
from typing import Tuple
import pytest
......@@ -53,7 +52,7 @@ MODEL_ARG_EXPTYPES = [
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
def test_auto_gptq(model_arg_exptype: tuple[str, None, str]) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype
try:
......
......@@ -5,7 +5,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import pytest
import torch
......@@ -58,7 +58,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Name of the quantization method."""
return "custom_quant"
def get_supported_act_dtypes(self) -> List["torch.dtype"]:
def get_supported_act_dtypes(self) -> list["torch.dtype"]:
"""List of supported activation dtypes."""
return [torch.float16, torch.bfloat16]
......@@ -68,12 +68,12 @@ class CustomQuantConfig(QuantizationConfig):
return -1
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig":
"""Create a config class from the model's quantization config."""
return CustomQuantConfig(num_bits=config.get("num_bits", 8))
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
......@@ -70,7 +68,7 @@ def test_get_prompt_logprobs(
assert (len(logprobs) == num_top_logprobs
or len(logprobs) == num_top_logprobs + 1)
output_text = result.outputs[0].text
output_string_from_most_likely_tokens_lst: List[str] = []
output_string_from_most_likely_tokens_lst: list[str] = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append(
......
......@@ -4,7 +4,7 @@
Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import List, Optional
from typing import Optional
from transformers import AutoTokenizer
......@@ -16,8 +16,8 @@ def _generate(
prompt: str,
num_prompt_tokens: int,
temperature: float = 0,
bad_words: Optional[List[str]] = None,
) -> List[int]:
bad_words: Optional[list[str]] = None,
) -> list[int]:
sampling_params = SamplingParams(
temperature=temperature,
bad_words=bad_words,
......@@ -59,7 +59,7 @@ class TestOneTokenBadWord:
def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
bad_words: Optional[list[str]] = None) -> list[int]:
return _generate(
model=model,
prompt=self.PROMPT,
......@@ -69,7 +69,7 @@ class TestOneTokenBadWord:
def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
add_special_tokens: bool = True) -> list[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids
......@@ -149,7 +149,7 @@ class TestTwoTokenBadWord:
def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
bad_words: Optional[list[str]] = None) -> list[int]:
return _generate(
model=model,
prompt=self.PROMPT,
......@@ -158,7 +158,7 @@ class TestTwoTokenBadWord:
)
@staticmethod
def _contains(sequence: List[int], subsequence: List[int]) -> bool:
def _contains(sequence: list[int], subsequence: list[int]) -> bool:
searched = False
for start in range(len(sequence)):
......@@ -181,6 +181,6 @@ class TestTwoTokenBadWord:
def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
add_special_tokens: bool = True) -> list[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids
# SPDX-License-Identifier: Apache-2.0
"""Tests for rejection sampling."""
from typing import List, Tuple
import pytest
import torch
......@@ -416,8 +415,8 @@ def test_rejection_sampling_approximates_target_distribution(
draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference: List[float] = []
distance_wrt_target: List[float] = []
distance_wrt_reference: list[float] = []
distance_wrt_target: list[float] = []
for num_samples in sample_sizes:
(reference_vs_rejsample_dist,
......@@ -452,7 +451,7 @@ def test_rejection_sampling_approximates_target_distribution(
expected_improvement_multiplier)
def get_ratio_first_to_last(elements: List[float]) -> float:
def get_ratio_first_to_last(elements: list[float]) -> float:
return elements[0] / elements[-1]
......@@ -477,7 +476,7 @@ class _CorrectnessTestHelper:
def generate_probs_for_test(
self, draft_and_target_probs_equal: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
draft_probs, target_probs = (F.softmax(
torch.rand(self.vocab_size, dtype=torch.float32),
dim=-1,
......@@ -499,7 +498,7 @@ class _CorrectnessTestHelper:
def run_and_compare_distributions(self, draft_probs: torch.Tensor,
target_probs: torch.Tensor,
reference_probs: torch.Tensor,
num_samples: int) -> Tuple[float, float]:
num_samples: int) -> tuple[float, float]:
# Sample using rejection sampling.
rej_sample_probs = self._estimate_rejection_sampling_pdf(
draft_probs, target_probs, num_samples)
......
......@@ -3,7 +3,7 @@
import itertools
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Optional
from unittest.mock import Mock, patch
import pytest
......@@ -30,7 +30,7 @@ class MockLogitsSampler(Sampler):
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
......@@ -53,8 +53,8 @@ def _do_sample(
sampling_params: SamplingParams,
device: str,
):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: list[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
......@@ -171,7 +171,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
*,
stop_token_ids: Optional[List[int]] = None,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
......@@ -196,7 +196,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
batch_size = random.randint(1, 128)
expected_penalization = []
sequence_metadata_list: List[SequenceGroupMetadata] = []
sequence_metadata_list: list[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2
while batch_size > 0:
......@@ -216,8 +216,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)
seq_data: Dict[int, SequenceData] = {}
seq_group_penalization: List[bool] = []
seq_data: dict[int, SequenceData] = {}
seq_group_penalization: list[bool] = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100)
......@@ -376,16 +376,16 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
else:
test_cases = [generate_test_case()]
def run_test_case(*, expected_penalization: List[bool],
seq_group_metadata_list: List[SequenceGroupMetadata]):
def run_test_case(*, expected_penalization: list[bool],
seq_group_metadata_list: list[SequenceGroupMetadata]):
assert expected_penalization, \
"Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
seq_lens: List[int] = []
sampling_params_per_row: List[SamplingParams] = []
seq_lens: list[int] = []
sampling_params_per_row: list[SamplingParams] = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
......@@ -456,11 +456,11 @@ def test_sampler_mixed(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = []
seq_lens: List[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
expected_tokens: list[Optional[list[int]]] = []
seq_lens: list[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
expected: Optional[list[int]] = None
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
......@@ -492,7 +492,7 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
generators: Dict[str, torch.Generator] = {}
generators: dict[str, torch.Generator] = {}
def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
......@@ -587,8 +587,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=device)
assert len(processors) == 2 # top_p and top_k
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: list[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
......@@ -669,10 +669,10 @@ def test_sampler_repetition_penalty_mixed(device: str):
vocab_size = 8
def test_sampling_params(sampling_params: List[SamplingParams]):
def test_sampling_params(sampling_params: list[SamplingParams]):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
seq_lens: list[int] = []
for i in range(2):
seq_group_metadata_list.append(
SequenceGroupMetadata(
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
from typing import Optional, Union
import pytest
import torch
......@@ -64,9 +65,9 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
sampling_params) -> tuple[list[str], list[list[int]], float]:
tokens: list[str] = []
token_ids: list[list[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
......@@ -42,7 +40,7 @@ def test_get_token_ids_to_score(k: int):
device='cuda',
)
expected_output: List[List[int]] = [
expected_output: list[list[int]] = [
[],
]
for i in range(proposal_token_ids.shape[0]):
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
......@@ -221,7 +220,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
single_step_output: list[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
......@@ -243,15 +242,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs: List[List[Dict[int,
multi_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: List[List[Dict[int,
single_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
......@@ -336,7 +335,7 @@ def test_multi_step_with_batch_expansion_correct_output():
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
......@@ -430,7 +429,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List
import pytest
import torch
......@@ -15,7 +14,7 @@ from vllm.worker.worker import Worker
from .utils import create_batch, create_worker
def create_proposal(propose_lens: List[int], vocab_size: int,
def create_proposal(propose_lens: list[int], vocab_size: int,
device: str) -> SpeculativeProposals:
batch_size = len(propose_lens)
max_propose_len = max(propose_lens)
......
......@@ -3,7 +3,6 @@
import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock
import pytest
......@@ -123,7 +122,7 @@ def test_batch_expansion_correctly_calls_target_model(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts: List[List[int]] = []
seen_contexts: list[list[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
......@@ -136,7 +135,7 @@ def test_batch_expansion_correctly_calls_target_model(
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts: List[List[int]] = []
expected_seen_contexts: list[list[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
......@@ -338,11 +337,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
......@@ -728,7 +727,7 @@ def test_populate_seq_ids_with_bonus_tokens():
size=(batch_size, (k + 1)),
dtype=torch.int64,
device='cuda')
expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence as GenericSequence
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from typing import Callable, Optional, TypeVar, Union
from unittest.mock import MagicMock
import torch
......@@ -44,7 +43,7 @@ def mock_worker(cls=None,
return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
......@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model
def zero_kv_cache(cache_engine: List[CacheEngine]):
def zero_kv_cache(cache_engine: list[CacheEngine]):
assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_()
......@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
prompts: list[list[int]],
num_gpu_blocks: int,
block_size: int,
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
final_prompt_lens: list[int],
continuations: Optional[list[list[int]]] = None,
seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
......@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
def create_chunked_seq_group_metadata_from_prompt(
prompt: List[int],
prompt: list[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
......@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
actual_logprobs: list[dict[int, Logprob]],
expected_logprobs: list[dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
......@@ -202,7 +201,7 @@ def create_sampler_output_list(
token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
......@@ -231,9 +230,9 @@ def create_sampler_output_list(
def create_batch(batch_size,
k,
prompt_len: Union[int, List[int]] = 10,
prompt_len: Union[int, list[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None,
seq_ids: Optional[list[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
......
......@@ -3,7 +3,7 @@
Run `pytest tests/test_cache_block_hashing.py`.
"""
from typing import List, Optional
from typing import Optional
import pytest
......@@ -44,7 +44,7 @@ def flatten_2d(li):
@pytest.mark.parametrize("concurrent_lora_int_ids",
[[None], [1], [None, 1], [None, 1, 2], [1, 2]])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
concurrent_lora_int_ids: List[Optional[int]]):
concurrent_lora_int_ids: list[Optional[int]]):
tokenizer = TokenizerGroup(
tokenizer_id="facebook/opt-125m",
......@@ -53,7 +53,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
max_input_length=None,
)
hashes: List[List[List[int]]] = []
hashes: list[list[list[int]]] = []
for prefix in prefixes:
for lora_int_id in concurrent_lora_int_ids:
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
from vllm.inputs import zip_enc_dec_prompts
......@@ -45,7 +43,7 @@ def test_parse_single_batch_string_consistent(string_input: str):
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
def test_parse_single_batch_token_consistent(token_input: list[int]):
assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input])
......
......@@ -155,7 +155,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
with pytest.raises(ValueError) as ex_info:
_configure_vllm_root_logger()
assert ex_info.type == ValueError # noqa: E721
assert "Invalid logging config. Expected Dict, got" in str(ex_info)
assert "Invalid logging config. Expected dict, got" in str(ex_info)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Tuple
from unittest.mock import patch
import pytest
......@@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor):
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
......
......@@ -3,7 +3,7 @@
import asyncio
import os
import socket
from typing import AsyncIterator, Tuple
from collections.abc import AsyncIterator
from unittest.mock import patch
import pytest
......@@ -33,7 +33,7 @@ async def test_merge_async_iterators():
iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators)
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
async for idx, output in generator:
print(f"idx: {idx}, output: {output}")
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Generator, List, Optional
from collections.abc import Generator
from typing import Any, Optional
import pytest
from transformers import AutoTokenizer
......@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> List[int]:
tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids
......@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
def create_dummy_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
......@@ -186,10 +187,10 @@ def create_dummy_logprobs(
def create_dummy_prompt_logprobs(
complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
# logprob for the first prompt token is None.
logprobs: List[Optional[Dict[int, Any]]] = [None]
logprobs: list[Optional[dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs
......@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
......@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token: List[str] = []
sequential_logprobs_text_chosen_token: list[str] = []
sequential_logprobs_text_other_token: list[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
......@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
......@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
# decoded_prompt_logprobs doesn't contain the first token.
......
......@@ -3,7 +3,7 @@
import asyncio
import os
import sys
from typing import List, Optional
from typing import Optional
from unittest.mock import patch
import pytest
......@@ -129,7 +129,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
def __init__(self,
*args,
fail_at: Optional[List[int]] = None,
fail_at: Optional[list[int]] = None,
**kwargs):
super().__init__(*args, **kwargs)
self.i = 0
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
......@@ -17,15 +17,15 @@ class TestTokenizer(TokenizerBase):
return TestTokenizer()
@property
def all_special_tokens_extended(self) -> List[str]:
def all_special_tokens_extended(self) -> list[str]:
raise NotImplementedError()
@property
def all_special_tokens(self) -> List[str]:
def all_special_tokens(self) -> list[str]:
raise NotImplementedError()
@property
def all_special_ids(self) -> List[int]:
def all_special_ids(self) -> list[int]:
raise NotImplementedError()
@property
......@@ -58,7 +58,7 @@ class TestTokenizer(TokenizerBase):
def __call__(
self,
text: Union[str, List[str], List[int]],
text: Union[str, list[str], list[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
......@@ -66,10 +66,10 @@ class TestTokenizer(TokenizerBase):
):
raise NotImplementedError()
def get_vocab(self) -> Dict[str, int]:
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError()
def get_added_vocab(self) -> Dict[str, int]:
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError()
def encode_one(
......@@ -77,33 +77,33 @@ class TestTokenizer(TokenizerBase):
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
) -> list[int]:
raise NotImplementedError()
def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
add_special_tokens: Optional[bool] = None) -> list[int]:
raise NotImplementedError()
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs) -> List[int]:
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None,
**kwargs) -> list[int]:
raise NotImplementedError()
def convert_tokens_to_string(self, tokens: List[str]) -> str:
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError()
def decode(self,
ids: Union[List[int], int],
ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
def convert_ids_to_tokens(
self,
ids: List[int],
ids: list[int],
skip_special_tokens: bool = True,
) -> List[str]:
) -> list[str]:
raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment