Unverified commit cf069aa8, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0
from typing import List
import openai
import pytest
......@@ -45,7 +43,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
logprobs=False,
stream=True,
)
chunks: List[str] = []
chunks: list[str] = []
finish_reason_count = 0
role_sent: bool = False
......@@ -116,7 +114,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
stream=True,
)
chunks: List[str] = []
chunks: list[str] = []
finish_reason_count = 0
role_sent: bool = False
......
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Generator, List, Optional
from collections.abc import Generator
from typing import Optional
import partial_json_parser
import pytest
......@@ -26,8 +27,8 @@ def jamba_tool_parser(jamba_tokenizer):
return JambaToolParser(jamba_tokenizer)
def assert_tool_calls(actual_tool_calls: List[ToolCall],
expected_tool_calls: List[ToolCall]):
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
......@@ -218,10 +219,10 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output, expected_tool_calls,
expected_content):
other_content: str = ''
function_names: List[str] = []
function_args_strs: List[str] = []
function_names: list[str] = []
function_args_strs: list[str] = []
tool_call_idx: int = -1
tool_call_ids: List[Optional[str]] = []
tool_call_ids: list[Optional[str]] = []
for delta_message in stream_delta_message_generator(
jamba_tool_parser, jamba_tokenizer, model_output):
......
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Dict, List, Optional
from typing import Optional
import openai
import pytest
......@@ -54,7 +54,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert isinstance(tool_call.function.arguments, str)
parsed_arguments = json.loads(tool_call.function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments, dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
......@@ -73,8 +73,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
role_name: Optional[str] = None
finish_reason_count: int = 0
tool_call_names: List[str] = []
tool_call_args: List[str] = []
tool_call_names: list[str] = []
tool_call_args: list[str] = []
tool_call_idx: int = -1
tool_call_id_count: int = 0
......@@ -180,7 +180,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
logprobs=False,
stream=True)
chunks: List[str] = []
chunks: list[str] = []
finish_reason_count = 0
role_sent: bool = False
......
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Dict, List, Optional
from typing import Optional
import openai
import pytest
......@@ -44,7 +44,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# make sure the arguments parse properly
parsed_arguments = json.loads(tool_calls[0].function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments, dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
assert parsed_arguments.get("city") == "Dallas"
......@@ -117,7 +117,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# validate arguments
streamed_args = json.loads(function_args_str)
assert isinstance(streamed_args, Dict)
assert isinstance(streamed_args, dict)
assert isinstance(streamed_args.get("city"), str)
assert isinstance(streamed_args.get("state"), str)
assert streamed_args.get("city") == "Dallas"
......@@ -128,7 +128,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert choice.message.role == role_name
assert choice.message.tool_calls[0].function.name == function_name
# compare streamed with non-streamed args Dict-wise, not string-wise
# compare streamed with non-streamed args dict-wise, not string-wise
# because character-to-character comparison might not work e.g. the tool
# call parser adding extra spaces or something like that. we care about the
# dicts matching not byte-wise match
......@@ -167,7 +167,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
logprobs=False,
stream=True)
chunks: List[str] = []
chunks: list[str] = []
finish_reason_count = 0
role_sent: bool = False
......
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
......@@ -12,14 +12,14 @@ from tests.utils import VLLM_PATH
class ServerConfig(TypedDict, total=False):
model: str
arguments: List[str]
arguments: list[str]
system_prompt: Optional[str]
supports_parallel: Optional[bool]
supports_rocm: Optional[bool]
def patch_system_prompt(messages: List[Dict[str, Any]],
system_prompt: str) -> List[Dict[str, Any]]:
def patch_system_prompt(messages: list[dict[str, Any]],
system_prompt: str) -> list[dict[str, Any]]:
new_messages = deepcopy(messages)
if new_messages[0]["role"] == "system":
new_messages[0]["content"] = system_prompt
......@@ -28,8 +28,8 @@ def patch_system_prompt(messages: List[Dict[str, Any]],
return new_messages
def ensure_system_prompt(messages: List[Dict[str, Any]],
config: ServerConfig) -> List[Dict[str, Any]]:
def ensure_system_prompt(messages: list[dict[str, Any]],
config: ServerConfig) -> list[dict[str, Any]]:
prompt = config.get("system_prompt")
if prompt:
return patch_system_prompt(messages, prompt)
......@@ -39,9 +39,9 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
CONFIGS: Dict[str, ServerConfig] = {
CONFIGS: dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-3-Llama-3.1-8B",
......@@ -205,7 +205,7 @@ SEARCH_TOOL: ChatCompletionToolParam = {
}
}
MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
......@@ -222,14 +222,14 @@ MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
"Can you tell me a joke please?"
}]
MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
......@@ -258,7 +258,7 @@ MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"cloudy skies and a low chance of rain."
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
......@@ -266,7 +266,7 @@ MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
"Fahrenheit?"
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
......
......@@ -2,8 +2,9 @@
import os
import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Dict, Iterable, Literal
from typing import Callable, Literal
import grpc
import pytest
......@@ -25,7 +26,7 @@ FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
def decode_value(value: AnyValue):
field_decoders: Dict[FieldName, Callable] = {
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
......
......@@ -11,7 +11,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Optional, Union
import openai
import pytest
......@@ -73,9 +73,9 @@ class RemoteOpenAIServer:
def __init__(self,
model: str,
vllm_serve_args: List[str],
vllm_serve_args: list[str],
*,
env_dict: Optional[Dict[str, str]] = None,
env_dict: Optional[dict[str, str]] = None,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
if auto_port:
......@@ -183,7 +183,7 @@ def _test_completion(
client: openai.OpenAI,
model: str,
prompt: str,
token_ids: List[int],
token_ids: list[int],
):
results = []
......@@ -400,10 +400,10 @@ def _test_image_text(
def compare_two_settings(model: str,
arg1: List[str],
arg2: List[str],
env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None,
arg1: list[str],
arg2: list[str],
env1: Optional[dict[str, str]] = None,
env2: Optional[dict[str, str]] = None,
*,
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
......@@ -429,8 +429,8 @@ def compare_two_settings(model: str,
def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
all_args: list[list[str]],
all_envs: list[Optional[dict[str, str]]],
*,
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
......@@ -470,7 +470,7 @@ def compare_all_settings(model: str,
prompt = "Hello, my name is"
token_ids = tokenizer(prompt).input_ids
ref_results: List = []
ref_results: list = []
for i, (args, env) in enumerate(zip(all_args, all_envs)):
if can_force_load_format:
# we are comparing the results and
......@@ -481,7 +481,7 @@ def compare_all_settings(model: str,
# environment variable to force the load format,
# e.g. in quantization tests.
args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT]
compare_results: List = []
compare_results: list = []
results = ref_results if i == 0 else compare_results
with RemoteOpenAIServer(model,
args,
......@@ -582,7 +582,7 @@ def multi_process_parallel(
@contextmanager
def error_on_warning(category: Type[Warning] = Warning):
def error_on_warning(category: type[Warning] = Warning):
"""
Within the scope of this context manager, tests will fail if any warning
of the given category is emitted.
......@@ -604,7 +604,7 @@ def get_physical_device_indices(devices):
@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
def wait_for_gpu_memory_to_clear(devices: list[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
......@@ -612,8 +612,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
devices = get_physical_device_indices(devices)
start_time = time.time()
while True:
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
output: dict[int, str] = {}
output_raw: dict[int, float] = {}
for device in devices:
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
......@@ -758,13 +758,13 @@ def multi_gpu_test(*, num_gpus: int):
async def completions_with_server_args(
prompts: List[str],
prompts: list[str],
model_name: str,
server_cli_args: List[str],
server_cli_args: list[str],
num_logprobs: Optional[int],
max_wait_seconds: int = 240,
max_tokens: Union[int, list] = 5,
) -> List[Completion]:
) -> list[Completion]:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
......@@ -807,7 +807,7 @@ async def completions_with_server_args(
return outputs
def get_client_text_generations(completions: List[Completion]) -> List[str]:
def get_client_text_generations(completions: list[Completion]) -> list[str]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
......@@ -816,7 +816,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]:
def get_client_text_logprob_generations(
completions: List[Completion]) -> List[TextTextLogprobs]:
completions: list[Completion]) -> list[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
......
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
from typing import List
import pytest
......@@ -434,7 +433,7 @@ def test_cache_blocks():
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: List[BlockHashType] = []
block_hashes: list[BlockHashType] = []
block_pool.cache_full_blocks(
request=req,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Optional
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
......@@ -48,9 +48,9 @@ def create_scheduler(
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[List[PlaceholderRange]] = None,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[List[int]] = None,
stop_token_ids: Optional[list[int]] = None,
):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple
import pytest
import torch
from transformers import AutoTokenizer
......@@ -17,8 +15,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]]
EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor]
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor]
def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
......
......@@ -2,7 +2,7 @@
import asyncio
from contextlib import ExitStack
from typing import List, Optional, Tuple
from typing import Optional
import pytest
......@@ -47,7 +47,7 @@ async def generate(engine: AsyncLLM,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
......@@ -114,7 +114,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc(
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind,
engine_args_and_prompt: Tuple[AsyncEngineArgs,
engine_args_and_prompt: tuple[AsyncEngineArgs,
PromptType]):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
......@@ -160,7 +160,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_abort(monkeypatch, output_kind: RequestOutputKind,
engine_args_and_prompt: Tuple[AsyncEngineArgs,
engine_args_and_prompt: tuple[AsyncEngineArgs,
PromptType]):
with monkeypatch.context() as m, ExitStack() as after:
......@@ -177,7 +177,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks: List[asyncio.Task] = []
tasks: list[asyncio.Task] = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
......
......@@ -5,7 +5,6 @@ import threading
import time
import uuid
from concurrent.futures import Future
from typing import List
import pytest
from transformers import AutoTokenizer
......@@ -213,7 +212,7 @@ def test_engine_core_concurrent_batches(monkeypatch):
class DummyExecutor(UniProcExecutor):
def initialize_from_config(
self, kv_cache_configs: List[KVCacheConfig]) -> None:
self, kv_cache_configs: list[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)
# This executor actually can only run 1 batch at a time
......
......@@ -3,7 +3,7 @@
import asyncio
import time
import uuid
from typing import Dict, List, Optional
from typing import Optional
import pytest
from transformers import AutoTokenizer
......@@ -44,7 +44,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
)
def loop_until_done(client: EngineCoreClient, outputs: Dict):
def loop_until_done(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = client.get_output().outputs
......@@ -62,7 +62,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
break
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = (await client.get_output_async()).outputs
......@@ -121,7 +121,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client.add_request(request)
time.sleep(0.01)
outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
loop_until_done(client, outputs)
for req_id in request_ids:
......@@ -207,7 +207,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
await client.add_request_async(request)
await asyncio.sleep(0.01)
outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for req_id in request_ids:
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Dict, List, Optional, Tuple
from typing import Optional
import pytest
......@@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch):
def _get_test_sampling_params(
prompt_list: List[str],
prompt_list: list[str],
seed: Optional[int] = 42,
) -> Tuple[List[SamplingParams], List[int]]:
) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int:
......@@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
# Validate each request response
for out, n in zip(outputs, n_list):
completion_counts: Dict[str, int] = {}
completion_counts: dict[str, int] = {}
# Assert correct number of completions
assert len(out.outputs) == n, (
f"{len(out.outputs)} completions; {n} expected.")
......
......@@ -2,7 +2,7 @@
import math
import time
from typing import Dict, List, Optional
from typing import Optional
import pytest
......@@ -112,12 +112,12 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
def _validate_logprobs(
gen_tokens: Dict[str, List[int]],
gen_logprobs: Dict[str, Optional[SampleLogprobs]],
gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]],
gen_cumulative_logprob: Dict[str, float],
gen_tokens: dict[str, list[int]],
gen_logprobs: dict[str, Optional[SampleLogprobs]],
gen_prompt_logprobs: dict[str, Optional[PromptLogprobs]],
gen_cumulative_logprob: dict[str, float],
dtv: DummyOutputProcessorTestVectors,
request_id_list: List[str],
request_id_list: list[str],
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
) -> None:
......
......@@ -2,7 +2,7 @@
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
......@@ -61,7 +61,7 @@ def _create_random_top_logprob_test_vector(
def _create_random_top_logprob_test_matrix(
shape: Tuple,
shape: tuple,
lower: float,
upper: float,
) -> torch.Tensor:
......@@ -90,7 +90,7 @@ def _create_random_top_token_test_vector(
lower: int,
upper: int,
sampled_token_id: int,
adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]:
adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]:
"""Create a random vector of top logprob token indices
Use to create fake sample logprobs for testing. The sampled token
......@@ -141,11 +141,11 @@ def _create_random_top_token_test_vector(
def _create_random_top_token_test_matrix(
shape: Tuple[int, int],
shape: tuple[int, int],
lower: int,
upper: int,
tokens_list: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
tokens_list: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Create a random matrix of top logprob token indices
Use to create fake prompt logprobs for testing.
......@@ -160,7 +160,7 @@ def _create_random_top_token_test_matrix(
upper: upper range of token ids
Returns:
Tuple containing:
tuple containing:
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
- 1D tensor of ranks of prompt tokens in their respective
rows, or random values
......@@ -206,10 +206,10 @@ def decode_token(
def generate_dummy_sample_logprobs(
sampled_tokens_list: List,
sampled_tokens_list: list,
num_logprobs: int,
tokenizer: PreTrainedTokenizer,
) -> List[Tuple[List[int], List[float], int]]:
) -> list[tuple[list[int], list[float], int]]:
"""Generate dummy sample logprobs
Generate a test data structure which imitates the list of sample logprobs
......@@ -221,7 +221,7 @@ def generate_dummy_sample_logprobs(
tokenizer: model tokenizer to use for detokenization
Returns
List of (top token ids vector, logprobs vector, sampled token rank)
list of (top token ids vector, logprobs vector, sampled token rank)
Python lists tuples; in each tuple the logprobs and top token ids
vectors have the same length which is either `num_logprobs` or
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
......@@ -253,7 +253,7 @@ def generate_dummy_sample_logprobs(
def generate_dummy_prompt_logprobs_tensors(
prompt_tokens_list: List,
prompt_tokens_list: list,
num_logprobs: int,
tokenizer: PreTrainedTokenizer,
) -> LogprobsTensors:
......@@ -269,7 +269,7 @@ def generate_dummy_prompt_logprobs_tensors(
tokenizer: model tokenizer to use for detokenization
Returns
Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor,
Single tuple of (logprobs matrix, top token ids matrix) torch Tensor,
where both matrices have dimensions
num_prompt_tokens x num_logprobs
"""
......@@ -301,19 +301,19 @@ class DummyOutputProcessorTestVectors:
tokenizer: GeneralTokenizerType
tokenizer_group: BaseTokenizerGroup
vllm_config: EngineArgs
full_tokens: List[List[int]] # Prompt + generated tokens
prompt_tokens: List[List[int]]
generation_tokens: List[List[int]]
full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]]
generation_tokens: list[list[int]]
# Each request is associated with a tuple of
# (top tokens, top logprobs, ranks) prompt logprobs tensors
prompt_logprobs: List[LogprobsTensors]
prompt_logprobs: list[LogprobsTensors]
# Each request is associated with a sample logprobs; a request's
# sample logprobs are a list of (top tokens, top logprobs, ranks)
# sample logprobs tensors at each sequence position
generation_logprobs: List[List[Tuple[List[int], List[float], int]]]
prompt_strings: List[str]
prompt_strings_len: List[int]
generation_strings: List[str]
generation_logprobs: list[list[tuple[list[int], list[float], int]]]
prompt_strings: list[str]
prompt_strings_len: list[int]
generation_strings: list[str]
class MockEngineCore:
......@@ -321,18 +321,18 @@ class MockEngineCore:
def __init__(
self,
tokens_list: List[List[int]],
tokens_list: list[list[int]],
# For each request, for each sampled token offset,
# a tuple of
# (list of topk token ids, list of sample logprob vals, rank)
generated_logprobs_raw: Optional[List[List[Tuple[List[int],
List[float],
generated_logprobs_raw: Optional[list[list[tuple[list[int],
list[float],
int]]]] = None,
# For each request, a tuple of
# (prompt logprob val matrix, prompt logprob tok id matrix);
# each matrix has dimensions
# (num prompt toks) x (num prompt logprobs+1)
prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None,
prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None,
) -> None:
self.tokens_list = tokens_list
self.current_idx = 0
......@@ -341,7 +341,7 @@ class MockEngineCore:
self.prompt_logprobs_raw = prompt_logprobs_raw
self.do_prompt_logprobs = prompt_logprobs_raw is not None
def get_outputs(self) -> List[EngineCoreOutput]:
def get_outputs(self) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs
do_prompt_logprobs = self.do_prompt_logprobs
token_idx = self.current_idx
......
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Dict, List, Optional
from typing import Optional
import openai # use the official client for correctness check
import pytest
......@@ -193,7 +193,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: Dict = {
params: dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
......@@ -237,7 +237,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: List[str] = []
chunks: list[str] = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
......@@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
num_completions = len(completion.choices)
assert num_completions == n, (
f"Num completions {num_completions} but expected {n}.")
completion_repeats: Dict[str, int] = {}
completion_repeats: dict[str, int] = {}
for idx, choice in enumerate(completion.choices):
# Assert correct completion index & some finish reason.
assert choice.index == idx, (
......@@ -321,7 +321,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
temperature=0.95,
stream=True,
seed=42)
chunks: List[List[str]] = [[] for i in range(n)]
chunks: list[list[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
......@@ -332,7 +332,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
# Assert `n` completions with correct finish reasons
assert finish_reason_count == n, (
f"Expected {n} completions with valid indices and finish_reason.")
completion_repeats: Dict[str, int] = {}
completion_repeats: dict[str, int] = {}
for chunk in chunks:
chunk_len = len(chunk)
# Assert correct number of completion tokens
......
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import List, Tuple
import pytest
import torch
......@@ -46,8 +45,8 @@ def hf_model(hf_runner):
def _repeat_logprob_config(
test_prompts,
logprob_prompt_logprob_list: List[Tuple],
) -> List[Tuple]:
logprob_prompt_logprob_list: list[tuple],
) -> list[tuple]:
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
......@@ -74,7 +73,7 @@ def _repeat_logprob_config(
tuples
Returns:
List of
list of
(optional num sample logprob,optional num prompt logprob)
tuples which is either identical to
`logprob_prompt_logprob_list`, or else repeats
......@@ -177,7 +176,7 @@ def _test_case_get_logprobs_and_prompt_logprobs(
for r in range(1, num_top_logprobs + 1))
output_text = vllm_result.outputs[0].text
output_string_from_most_likely_tokens_lst: List[str] = []
output_string_from_most_likely_tokens_lst: list[str] = []
for top_logprobs in vllm_result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append(
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
......@@ -13,7 +12,7 @@ def sampler():
return RejectionSampler()
def create_logits_tensor(token_ids: List[int],
def create_logits_tensor(token_ids: list[int],
vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
......@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
return logits
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
return SamplingMetadata(
temperature=torch.tensor([]),
......@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
def test_empty_sequence(sampler):
"""Test handling empty sequence of speculated tokens"""
spec_tokens: List[List[int]] = [[]]
spec_tokens: list[list[int]] = [[]]
output_tokens = [5] # Just the bonus token
metadata = create_sampling_metadata(spec_tokens)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import numpy as np
import pytest
......@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
def _create_prompt_tokens_tensor(
prompt_token_ids: List[List[int]],
prompt_token_ids: list[list[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
......@@ -49,8 +49,8 @@ def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
) -> list[Optional[dict[int, float]]]:
res: list[Optional[dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
......@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
vocab_size: int,
device: torch.device,
) -> SamplingMetadata:
output_token_ids: List[List[int]] = []
prompt_token_ids: List[List[int]] = []
output_token_ids: list[list[int]] = []
prompt_token_ids: list[list[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
......@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Dict[int, Tuple[int, Set[int]]]:
batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
......@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
min_tokens: dict[int, tuple[int, set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens[index] = (
......@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.
......@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
list, each with a different frequency.
Returns:
Tuple[List[List[int]], List[List[int]]]:
tuple[list[list[int]], list[list[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
......@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
batch, ordered by their frequency in the corresponding output
list.
"""
output_token_ids: List[List[int]] = []
sorted_token_ids_in_output: List[List[int]] = []
output_token_ids: list[list[int]] = []
sorted_token_ids_in_output: list[list[int]] = []
for _ in range(batch_size):
distinct_token_ids = np.random.choice(vocab_size,
size=np.random.randint(1, 10),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment