Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0
import re
from typing import List, Tuple
from vllm import CompletionOutput
def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]:
def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
......@@ -32,7 +31,7 @@ def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]:
Returns:
List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
tuples
"""
if batch_logprobs_composition == "NONE":
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import torch
from vllm.v1.utils import bind_kv_cache
......@@ -22,7 +20,7 @@ def test_bind_kv_cache():
'layers.2.self_attn': torch.zeros((1, )),
'layers.3.self_attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
'layers.0.self_attn']
......@@ -52,7 +50,7 @@ def test_bind_kv_cache_non_attention():
'model.layers.28.attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import numpy as np
import pytest
......@@ -22,22 +22,22 @@ MAX_NUM_PROMPT_TOKENS = 64
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]:
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
"""
Remove some requests randomly from the batch and returns a Tuple
Remove some requests randomly from the batch and returns a tuple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: Set[int] = set()
req_indices_to_remove: set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: Set[str] = set()
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
......@@ -45,9 +45,9 @@ def _remove_requests(
def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState],
req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int],
reqs: list[CachedRequestState],
req_ids_retained: set[int],
req_id_index_in_input_batch: dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
......@@ -55,8 +55,8 @@ def _construct_expected_sampling_metadata(
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
......@@ -191,7 +191,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: List[CachedRequestState] = []
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
......
......@@ -4,7 +4,8 @@ import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator
from collections.abc import Generator
from typing import Callable
@dataclasses.dataclass
......
......@@ -4,7 +4,8 @@ import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator, Generic, TypeVar
from collections.abc import Generator
from typing import Callable, Generic, TypeVar
_T = TypeVar("_T")
......
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import List
import pytest
import torch
......@@ -43,7 +42,7 @@ def test_empty_seq_group():
enable_chunked_prefill=False,
enforce_eager=True,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(
......@@ -103,9 +102,9 @@ def test_prepare_prompt(batch_size):
enforce_eager=True,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: list[int] = []
encoder_seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
......@@ -295,9 +294,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
enforce_eager=True,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: list[int] = []
encoder_seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {
0: [1],
1: [3]
......@@ -503,9 +502,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
} if multiple_seqs_per_seq_group else {
0: [1]
}
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: list[int] = []
encoder_seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
cross_block_table = [2]
expanded_batch_size = 0
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from typing import List, Tuple, Type
import torch
......@@ -27,15 +26,15 @@ class MockAttentionBackend(AttentionBackend):
raise NotImplementedError
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_metadata_cls() -> type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
return AttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
......@@ -44,7 +43,7 @@ class MockAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
raise NotImplementedError
@staticmethod
......@@ -57,7 +56,7 @@ class MockAttentionBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
pass
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
......@@ -42,8 +40,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill=False,
)
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
......@@ -159,8 +157,8 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
context_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
context_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
......@@ -265,7 +263,7 @@ def test_empty_seq_group():
dtype="float16",
enforce_eager=False,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
......@@ -315,10 +313,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
# Add prefill requests.
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
prefill_metadata_list: List[SequenceGroupMetadata] = []
decode_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
prefill_metadata_list: list[SequenceGroupMetadata] = []
decode_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
......
......@@ -2,13 +2,12 @@
import argparse
import json
from typing import Dict
from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string
def flatten_entries(entry_cls, profile_dict: Dict):
def flatten_entries(entry_cls, profile_dict: dict):
entries_and_depth = []
def get_entries(node, curr_depth=0):
......
......@@ -6,7 +6,7 @@ import json
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any, Optional
import matplotlib.pyplot as plt
import pandas as pd
......@@ -24,7 +24,7 @@ def largest_dist_from_leaf(node: dict, depth: int = 0):
def get_entries_at_depth(depth: int,
entries_and_traces: List[Tuple[Any, Any]],
entries_and_traces: list[tuple[Any, Any]],
node: dict,
curr_depth: int = 0,
trace=()):
......@@ -48,9 +48,9 @@ def get_entries_at_depth(depth: int,
trace=trace)
def fold_nodes(root: dict, nodes_to_fold: List[str]):
def fold_nodes(root: dict, nodes_to_fold: list[str]):
stack: List[dict] = [root]
stack: list[dict] = [root]
while len(stack) != 0:
node = stack.pop()
if node['entry']['name'] in nodes_to_fold:
......@@ -427,12 +427,12 @@ def main(
plot_metric: str,
make_names_unique: bool,
top_k: int,
json_nodes_to_fold: List[str]):
json_nodes_to_fold: list[str]):
def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame:
def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:
def get_entries_and_traces(key: str):
entries_and_traces: List[Tuple[Any, Any]] = []
entries_and_traces: list[tuple[Any, Any]] = []
for root in profile_json[key]["summary_stats"]:
# Fold nodes in the traces as per user request. i.e. simply
# make the requested nodes leaf-nodes.
......
......@@ -2,7 +2,7 @@
import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.library
......@@ -198,7 +198,7 @@ def rms_norm_dynamic_per_token_quant(
quant_dtype: torch.dtype,
scale_ub: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=quant_dtype)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
......@@ -347,7 +347,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
codebook_partition_sizes: list[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
out_features = codes.size(0) * codebooks.size(2)
flat_input = input.reshape((-1, input.size(-1)))
......@@ -363,7 +363,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
codebook_partition_sizes: list[int]) -> torch.Tensor:
in_features = codes.size(1) * 8
out_features = codes.size(0)
return torch.empty((out_features, in_features),
......@@ -554,7 +554,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def cutlass_sparse_compress(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
-> tuple[torch.Tensor, torch.Tensor]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
......@@ -571,7 +571,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
- `torch.float16`
Returns:
Tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
......@@ -646,14 +646,14 @@ def cutlass_scaled_sparse_mm(
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
codebook_partition_sizes: list[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
codebook_partition_sizes: list[int]) -> torch.Tensor:
return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes)
......@@ -738,7 +738,7 @@ def machete_supported_schedules(
group_zeros_type: Optional[torch.dtype] = None,
channel_scales_type: Optional[torch.dtype] = None,
token_scales_type: Optional[torch.dtype] = None,
out_type: Optional[torch.dtype] = None) -> List[str]:
out_type: Optional[torch.dtype] = None) -> list[str]:
return torch.ops._C.machete_supported_schedules(
a_type, b_type.id, group_scales_type, group_zeros_type,
channel_scales_type, token_scales_type, out_type)
......@@ -783,7 +783,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
# fp4
def scaled_fp4_quant(
input: torch.Tensor,
input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
......@@ -798,7 +798,7 @@ def scaled_fp4_quant(
input_global_scale: A scalar scaling factor for the entire tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
two values are packed into a uint8 and float8_e4m3 scaling factors
in the sizzled layout.
"""
......@@ -845,7 +845,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
......@@ -866,12 +866,12 @@ def scaled_fp8_quant(
in the dynamic quantization case.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
shape: Union[tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
if current_platform.is_rocm() else torch.float8_e4m3fn
......@@ -903,7 +903,7 @@ def allspark_repack_weight(
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
has_zp: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
for Ampere W8A16 Fused Gemm kernel
......@@ -917,7 +917,7 @@ def allspark_repack_weight(
if use asymmetric quantization, has_zp = True.
Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K = qweight.shape[0]
......@@ -964,7 +964,7 @@ def scaled_int8_quant(
scale: Optional[torch.Tensor] = None,
azp: Optional[torch.Tensor] = None,
symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
......@@ -977,7 +977,7 @@ def scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
"""
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
......@@ -1165,13 +1165,13 @@ def concat_and_cache_mla(
scale)
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def copy_blocks_mla(kv_caches: List[torch.Tensor],
def copy_blocks_mla(kv_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
......@@ -1209,7 +1209,7 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
full_nvlink)
......@@ -1229,16 +1229,16 @@ def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: List[List[int]],
offsets: List[List[int]]) -> None:
def register_graph_buffers(fa: int, handles: list[list[int]],
offsets: list[list[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
......@@ -1246,7 +1246,7 @@ def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
......@@ -1272,7 +1272,7 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import torch
......@@ -18,7 +18,7 @@ class ipex_ops:
@staticmethod
def _reshape_activation_tensor(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
num = x.size(0)
d = x.size(1) // 2
x = x.reshape(num, 2, d)
......@@ -213,8 +213,8 @@ class ipex_ops:
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks( # type: ignore
key_caches,
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.sequence import Logprob
......@@ -17,14 +17,14 @@ class BeamSearchSequence:
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
logprobs: List[Dict[int, Logprob]]
tokens: list[int]
logprobs: list[dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
multi_modal_data: Optional["MultiModalDataDict"] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[dict[str, Any]] = None
@dataclass
......@@ -33,20 +33,20 @@ class BeamSearchOutput:
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]
sequences: list[BeamSearchSequence]
class BeamSearchInstance:
def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
def __init__(self, prompt_tokens: list[int]):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
]
self.completed: List[BeamSearchSequence] = []
self.completed: list[BeamSearchSequence] = []
def get_beam_search_score(
tokens: List[int],
tokens: list[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
......
......@@ -7,13 +7,14 @@ import hashlib
import json
import sys
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Final, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union)
import torch
from pydantic import BaseModel, Field, PrivateAttr
......@@ -67,20 +68,20 @@ _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items()
for task in tasks
}
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
......@@ -92,7 +93,7 @@ class SupportsHash(Protocol):
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
def metrics_info(self) -> dict[str, str]:
...
......@@ -209,7 +210,7 @@ class ModelConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors: list[Any] = []
factors.append(self.model)
factors.append(self.dtype)
factors.append(self.quantization)
......@@ -233,7 +234,7 @@ class ModelConfig:
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
......@@ -244,19 +245,19 @@ class ModelConfig:
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None,
enable_sleep_mode: bool = False,
override_generation_config: Optional[Dict[str, Any]] = None,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
......@@ -283,7 +284,7 @@ class ModelConfig:
hf_overrides_fn = None
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides_kw.update(hf_override)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
......@@ -505,8 +506,8 @@ class ModelConfig:
def _get_preferred_task(
self,
architectures: List[str],
supported_tasks: Set[_ResolvedTask],
architectures: list[str],
supported_tasks: set[_ResolvedTask],
) -> Optional[_ResolvedTask]:
model_id = self.model
if get_pooling_config(model_id, self.revision):
......@@ -516,7 +517,7 @@ class ModelConfig:
if self.registry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
......@@ -537,27 +538,27 @@ class ModelConfig:
def _resolve_task(
self,
task_option: Union[TaskOption, Literal["draft"]],
) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
registry = self.registry
architectures = self.architectures
runner_support: Dict[RunnerType, bool] = {
runner_support: dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription": registry.is_transcription_model(architectures),
"generate": registry.is_text_generation_model(architectures),
"pooling": registry.is_pooling_model(architectures),
}
supported_runner_types_lst: List[RunnerType] = [
supported_runner_types_lst: list[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks_lst: List[_ResolvedTask] = [
supported_tasks_lst: list[_ResolvedTask] = [
task for runner_type in supported_runner_types_lst
for task in _RUNNER_TASKS[runner_type]
]
......@@ -767,7 +768,7 @@ class ModelConfig:
self.use_async_output_proc = False
def get_hf_config_sliding_window(
self) -> Union[Optional[int], List[Optional[int]]]:
self) -> Union[Optional[int], list[Optional[int]]]:
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
......@@ -778,7 +779,7 @@ class ModelConfig:
return None
return getattr(self.hf_text_config, "sliding_window", None)
def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
......@@ -888,7 +889,7 @@ class ModelConfig:
return num_heads // parallel_config.tensor_parallel_size
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
if self.hf_text_config.model_type == "deepseek_mtp":
total_num_hidden_layers = getattr(self.hf_text_config,
......@@ -949,7 +950,7 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> Dict[str, Any]:
def try_get_generation_config(self) -> dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.hf_config_path or self.model,
......@@ -967,7 +968,7 @@ class ModelConfig:
return config.to_diff_dict()
def get_diff_sampling_param(self) -> Dict[str, Any]:
def get_diff_sampling_param(self) -> dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
......@@ -975,7 +976,7 @@ class ModelConfig:
set, an empty dictionary is returned.
Returns:
Dict[str, Any]: A dictionary with the differing sampling
dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
......@@ -1032,7 +1033,7 @@ class ModelConfig:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property
def supported_runner_types(self) -> Set[RunnerType]:
def supported_runner_types(self) -> set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
......@@ -1075,7 +1076,7 @@ class CacheConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors: list[Any] = []
factors.append(self.cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
......@@ -1183,7 +1184,7 @@ class TokenizerPoolConfig:
pool type.
"""
pool_size: int
pool_type: Union[str, Type["BaseTokenizerGroup"]]
pool_type: Union[str, type["BaseTokenizerGroup"]]
extra_config: dict
def compute_hash(self) -> str:
......@@ -1200,7 +1201,7 @@ class TokenizerPoolConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -1214,7 +1215,7 @@ class TokenizerPoolConfig:
@classmethod
def create_config(
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
......@@ -1285,7 +1286,7 @@ class LoadConfig:
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
ignore_patterns: Optional[Union[list[str], str]] = None
def compute_hash(self) -> str:
"""
......@@ -1301,7 +1302,7 @@ class LoadConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -1359,7 +1360,7 @@ class ParallelConfig:
# to "ray" if Ray is installed and fail otherwise. Note that tpu
# and hpu only support Ray for distributed inference.
distributed_executor_backend: Optional[Union[str,
Type["ExecutorBase"]]] = None
type["ExecutorBase"]]] = None
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
......@@ -1423,7 +1424,7 @@ class ParallelConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors: list[Any] = []
factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size)
return hashlib.sha256(str(factors).encode()).hexdigest()
......@@ -1600,7 +1601,7 @@ class SchedulerConfig:
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
def compute_hash(self) -> str:
"""
......@@ -1616,7 +1617,7 @@ class SchedulerConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -1752,7 +1753,7 @@ class DeviceConfig:
# no factors to consider.
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -1798,7 +1799,7 @@ class SpeculativeConfig:
"""
# no factors to consider.
# spec decode does not use `torch.compile` yet.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2261,7 +2262,7 @@ class LoRAConfig:
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
long_lora_scaling_factors: Optional[tuple[float]] = None
bias_enabled: bool = False
def compute_hash(self) -> str:
......@@ -2278,7 +2279,7 @@ class LoRAConfig:
"""
# no factors to consider.
# LoRA is not compatible with `torch.compile` .
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2350,7 +2351,7 @@ class PromptAdapterConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2395,7 +2396,7 @@ class MultiModalConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2431,7 +2432,7 @@ class PoolerConfig:
are returned.
"""
returned_token_ids: Optional[List[int]] = None
returned_token_ids: Optional[list[int]] = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
......@@ -2452,7 +2453,7 @@ class PoolerConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2469,7 +2470,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] #
_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] #
def _get_and_verify_dtype(
......@@ -2558,7 +2559,7 @@ def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
......@@ -2684,7 +2685,7 @@ def _get_and_verify_max_len(
def get_min_sliding_window(
sliding_window: Union[int, List[Optional[int]]]) -> int:
sliding_window: Union[int, list[Optional[int]]]) -> int:
if isinstance(sliding_window, list):
return min(s for s in sliding_window if s is not None)
......@@ -2692,7 +2693,7 @@ def get_min_sliding_window(
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]):
served_model_name: Optional[Union[str, list[str]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
......@@ -2731,7 +2732,7 @@ class DecodingConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2774,7 +2775,7 @@ class ObservabilityConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2833,7 +2834,7 @@ class KVTransferConfig(BaseModel):
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: List[Any] = []
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str
......@@ -2930,7 +2931,7 @@ class CompilationConfig(BaseModel):
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- List[int]: capture sizes are specified as given.
- list[int]: capture sizes are specified as given.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
......@@ -2972,17 +2973,17 @@ class CompilationConfig(BaseModel):
debug_dump_path: str = ""
cache_dir: str = ""
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default=None) # type: ignore
custom_ops: list[str] = Field(default_factory=list)
splitting_ops: list[str] = Field(default=None) # type: ignore
use_inductor: bool = True
compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
inductor_compile_config: dict = Field(default_factory=dict)
inductor_passes: dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_capture_sizes: Optional[list[int]] = None
cudagraph_copy_inputs: bool = False
class PassConfig(BaseModel):
......@@ -2998,7 +2999,7 @@ class CompilationConfig(BaseModel):
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
"""
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
......@@ -3026,20 +3027,20 @@ class CompilationConfig(BaseModel):
max_capture_size: int = PrivateAttr
local_cache_dir: str = PrivateAttr # local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to List[int] for better lookup performance.
bs_to_padded_graph_size: List[int] = PrivateAttr
# we can optimize it to list[int] for better lookup performance.
bs_to_padded_graph_size: list[int] = PrivateAttr
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: Set[str] = PrivateAttr
traced_files: set[str] = PrivateAttr
compilation_time: float = PrivateAttr
# Per-model forward context
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr
static_forward_context: dict[str, Any] = PrivateAttr
def compute_hash(self) -> str:
"""
......@@ -3053,7 +3054,7 @@ class CompilationConfig(BaseModel):
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors: list[Any] = []
factors.append(self.level)
factors.append(self.backend)
factors.append(self.custom_ops)
......@@ -3150,7 +3151,7 @@ class CompilationConfig(BaseModel):
return VllmBackend(vllm_config)
def init_with_cudagraph_sizes(self,
cudagraph_capture_sizes: List[int]) -> None:
cudagraph_capture_sizes: list[int]) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes."""
......@@ -3243,10 +3244,10 @@ class VllmConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: List[Any] = []
factors: list[Any] = []
# summarize vllm config
vllm_factors: List[Any] = []
vllm_factors: list[Any] = []
from vllm import __version__
vllm_factors.append(__version__)
if self.model_config:
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Mapping, MutableMapping, Optional
from typing import Optional
from urllib.parse import urlparse
import aiohttp
......
......@@ -10,7 +10,8 @@ import asyncio
import json
import ssl
from argparse import Namespace
from typing import Any, AsyncGenerator, Optional
from collections.abc import AsyncGenerator
from typing import Any, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
......
......@@ -5,10 +5,11 @@ import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cache, lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Optional, Tuple, TypeVar, Union, cast)
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
......@@ -117,7 +118,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
content: Union[str, list[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
......@@ -143,7 +144,7 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Union[Optional[str], List[Dict[str, str]]]
content: Union[Optional[str], list[dict[str, str]]]
"""The contents of the message"""
tool_call_id: Optional[str]
......@@ -495,13 +496,13 @@ class BaseMultiModalContentParser(ABC):
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
if placeholder:
self._placeholder_counts[placeholder] += 1
def mm_placeholder_counts(self) -> Dict[str, int]:
def mm_placeholder_counts(self) -> dict[str, int]:
return dict(self._placeholder_counts)
@abstractmethod
......@@ -652,12 +653,12 @@ def load_chat_template(
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders: List[str] = []
missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
......@@ -684,10 +685,10 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[
MM_PARSER_MAP: dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
......@@ -749,7 +750,7 @@ def _parse_chat_message_content_mm_part(
part)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
input_audio_params = cast(Dict[str, str], part)
input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
......@@ -773,7 +774,7 @@ def _parse_chat_message_content_parts(
mm_tracker: BaseMultiModalItemTracker,
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
) -> list[ConversationMessage]:
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
......@@ -791,7 +792,7 @@ def _parse_chat_message_content_parts(
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(List[str], content)
texts = cast(list[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
......@@ -866,7 +867,7 @@ def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
) -> List[ConversationMessage]:
) -> list[ConversationMessage]:
role = message["role"]
content = message.get("content")
......@@ -900,7 +901,7 @@ def _parse_chat_message_content(
return result
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
......@@ -916,12 +917,12 @@ def _postprocess_messages(messages: List[ConversationMessage]) -> None:
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: List[ConversationMessage] = []
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
......@@ -939,12 +940,12 @@ def parse_chat_messages(
def parse_chat_messages_futures(
messages: List[ChatCompletionMessageParam],
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
......@@ -963,7 +964,7 @@ def parse_chat_messages_futures(
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
conversation: list[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
......@@ -985,10 +986,10 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
**kwargs: Any,
) -> List[int]:
) -> list[int]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
......
......@@ -5,7 +5,7 @@ import argparse
import os
import signal
import sys
from typing import List, Optional, Tuple
from typing import Optional
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
......@@ -23,7 +23,7 @@ def _register_signal_handlers():
signal.signal(signal.SIGTSTP, signal_handler)
def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
_register_signal_handlers()
base_url = args.url
......@@ -43,7 +43,7 @@ def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None:
conversation: List[ChatCompletionMessageParam] = []
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
......@@ -100,7 +100,7 @@ class ChatCommand(CLISubcommand):
def cmd(args: argparse.Namespace) -> None:
model_name, client = _interactive_cli(args)
system_prompt = args.system_prompt
conversation: List[ChatCompletionMessageParam] = []
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
......@@ -168,5 +168,5 @@ class CompleteCommand(CLISubcommand):
return complete_parser
def cmd_init() -> List[CLISubcommand]:
def cmd_init() -> list[CLISubcommand]:
return [ChatCommand(), CompleteCommand()]
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import List
import uvloop
......@@ -59,5 +58,5 @@ class ServeSubcommand(CLISubcommand):
return make_arg_parser(serve_parser)
def cmd_init() -> List[CLISubcommand]:
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
......@@ -2,9 +2,9 @@
import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
import cloudpickle
import torch.nn as nn
......@@ -177,11 +177,11 @@ class LLM:
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
......@@ -246,7 +246,7 @@ class LLM:
self.request_counter = Counter()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
def get_engine_class() -> type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
......@@ -283,11 +283,11 @@ class LLM:
Sequence[SamplingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids)
......@@ -296,30 +296,30 @@ class LLM:
self,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: List[str],
prompts: list[str],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
......@@ -328,32 +328,32 @@ class LLM:
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
list[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
prompt_token_ids: list[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: Optional[List[str]] = None,
prompts: Optional[list[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
list[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
prompt_token_ids: list[list[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
......@@ -362,13 +362,13 @@ class LLM:
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
) -> list[RequestOutput]:
...
@deprecate_kwargs(
......@@ -379,17 +379,17 @@ class LLM:
def generate(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
Optional[Union[str, list[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[List[int]] = None,
) -> List[RequestOutput]:
priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
......@@ -440,7 +440,7 @@ class LLM:
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
......@@ -473,8 +473,8 @@ class LLM:
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
"""
Execute an RPC call on all workers.
......@@ -510,9 +510,9 @@ class LLM:
def beam_search(
self,
prompts: List[Union[TokensPrompt, TextPrompt]],
prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
......@@ -543,7 +543,7 @@ class LLM:
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
instances: List[BeamSearchInstance] = []
instances: list[BeamSearchInstance] = []
for prompt in prompts:
if is_token_prompt(prompt):
......@@ -553,12 +553,12 @@ class LLM:
instances.append(BeamSearchInstance(prompt_tokens))
for _ in range(max_tokens):
all_beams: List[BeamSearchSequence] = list(
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), []))
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances))
instance_start_and_end: List[Tuple[int, int]] = list(
instance_start_and_end: list[tuple[int, int]] = list(
zip(pos[:-1], pos[1:]))
if len(all_beams) == 0:
......@@ -620,19 +620,19 @@ class LLM:
def chat(
self,
messages: Union[List[ChatCompletionMessageParam],
List[List[ChatCompletionMessageParam]]],
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
list[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
tools: Optional[list[dict[str, Any]]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> list[RequestOutput]:
"""
Generate responses for a chat conversation.
......@@ -678,17 +678,17 @@ class LLM:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""
list_of_messages: List[List[ChatCompletionMessageParam]]
list_of_messages: list[list[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is List[List[...]]
list_of_messages = cast(List[List[ChatCompletionMessageParam]],
# messages is list[list[...]]
list_of_messages = cast(list[list[ChatCompletionMessageParam]],
messages)
else:
# messages is List[...]
# messages is list[...]
list_of_messages = [
cast(List[ChatCompletionMessageParam], messages)
cast(list[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer()
......@@ -699,7 +699,7 @@ class LLM:
tokenizer,
)
prompts: List[Union[TokensPrompt, TextPrompt]] = []
prompts: list[Union[TokensPrompt, TextPrompt]] = []
for msgs in list_of_messages:
# NOTE: _parse_chat_message_content_parts() currently doesn't
......@@ -712,7 +712,7 @@ class LLM:
content_format=resolved_content_format,
)
prompt_data: Union[str, List[int]]
prompt_data: Union[str, list[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt_data = apply_mistral_chat_template(
tokenizer,
......@@ -762,9 +762,9 @@ class LLM:
Sequence[PoolingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids)
......@@ -774,25 +774,25 @@ class LLM:
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: List[str],
prompts: list[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
......@@ -803,26 +803,26 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
prompt_token_ids: list[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: Optional[List[str]] = None,
prompts: Optional[list[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
prompt_token_ids: list[list[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
......@@ -831,11 +831,11 @@ class LLM:
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
...
@deprecate_kwargs(
......@@ -846,14 +846,14 @@ class LLM:
def encode(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
Optional[Union[str, list[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
......@@ -898,7 +898,7 @@ class LLM:
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
......@@ -926,9 +926,9 @@ class LLM:
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
) -> list[EmbeddingRequestOutput]:
"""
Generate an embedding vector for each prompt.
......@@ -966,9 +966,9 @@ class LLM:
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ClassificationRequestOutput]:
) -> list[ClassificationRequestOutput]:
"""
Generate class logits for each prompt.
......@@ -1003,29 +1003,29 @@ class LLM:
def _embedding_score(
self,
tokenizer: AnyTokenizer,
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
text_1: list[Union[str, TextPrompt, TokensPrompt]],
text_2: list[Union[str, TextPrompt, TokensPrompt]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
) -> list[ScoringRequestOutput]:
encoded_output: List[PoolingRequestOutput] = self.encode(
encoded_output: list[PoolingRequestOutput] = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
encoded_output_1: List[PoolingRequestOutput] = encoded_output[
encoded_output_1: list[PoolingRequestOutput] = encoded_output[
0:len(text_1)]
encoded_output_2: List[PoolingRequestOutput] = encoded_output[
encoded_output_2: list[PoolingRequestOutput] = encoded_output[
len(text_1):]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
scores: List[PoolingRequestOutput] = []
scores: list[PoolingRequestOutput] = []
scores = _cosine_similarity(tokenizer=tokenizer,
embed_1=encoded_output_1,
......@@ -1038,13 +1038,13 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
text_1: List[str],
text_2: List[str],
text_1: list[str],
text_2: list[str],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
) -> list[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
......@@ -1057,7 +1057,7 @@ class LLM:
pooling_params = PoolingParams()
tokenization_kwargs: Dict[str, Any] = {}
tokenization_kwargs: dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
......@@ -1094,9 +1094,9 @@ class LLM:
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs ``<text,text_pair>``.
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
......@@ -1162,12 +1162,12 @@ class LLM:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
input_text_1: List[str] = [ensure_str(t) for t in text_1]
input_text_1: list[str] = [ensure_str(t) for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2]
input_text_2: list[str] = [ensure_str(t) for t in text_2]
_validate_score_input_lens(input_text_1, input_text_2)
......@@ -1226,8 +1226,8 @@ class LLM:
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
prompts: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
):
# skip_tokenizer_init is now checked in engine
......@@ -1252,7 +1252,7 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
parsed_prompts: List[PromptType] = []
parsed_prompts: list[PromptType] = []
for i in range(num_requests):
item: PromptType
......@@ -1275,7 +1275,7 @@ class LLM:
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
priority: Optional[list[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
......@@ -1357,7 +1357,7 @@ class LLM:
def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
) -> list[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
......@@ -1370,7 +1370,7 @@ class LLM:
)
# Run the engine.
outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment