Unverified commit cf069aa8, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
......@@ -2,7 +2,8 @@
import json
import re
from typing import Dict, List, Sequence, Union
from collections.abc import Sequence
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser):
self.model_tokenizer = self.model_tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = []
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: List[str] = [
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
......
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Dict, Sequence, Union
from collections.abc import Sequence
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser):
# tool calls are generated in an object in internlm2
# it does not support parallel tool calls
try:
tool_call_arr: Dict = partial_json_parser.loads(
tool_call_arr: dict = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
......
......@@ -2,7 +2,8 @@
import json
import re
from typing import Dict, List, Sequence, Union
from collections.abc import Sequence
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser):
)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = []
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: List[str] = [
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool_calls>"
......@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: List[Dict] = partial_json_parser.loads(
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
......@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser):
# select as the current tool call the one our state is currently on
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
......
......@@ -2,8 +2,9 @@
import json
import re
from collections.abc import Sequence
from json import JSONDecoder
from typing import Dict, List, Sequence, Union
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token,
......@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser):
start_idx += end_idx + len('; ')
function_call_arr.append(obj)
tool_calls: List[ToolCall] = [
tool_calls: list[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
......@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser):
return None
# select as the current tool call the one our state is currently on
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
......
......@@ -2,9 +2,10 @@
import json
import re
from collections.abc import Sequence
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
......@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser):
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: List[MistralToolCall] = [
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
......@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: List[Dict] = partial_json_parser.loads(
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
......@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser):
# select as the current tool call the one our state is currently on
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
......
......@@ -3,7 +3,8 @@
import ast
import json
import re
from typing import Any, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Union
from transformers import PreTrainedTokenizerBase
......@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments=json.dumps(arguments)))
def _make_valid_python(text: str) -> Union[Tuple[str, str], None]:
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
......
......@@ -2,7 +2,7 @@
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any, List, Tuple
from typing import Any
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
return diff
def find_all_indices(string: str, substring: str) -> List[int]:
def find_all_indices(string: str, substring: str) -> list[int]:
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
......@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]:
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union
from typing import Union
from torch.nn import CosineSimilarity
......@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
def _cosine_similarity(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embed_1: List[PoolingRequestOutput],
embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:
embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
scorer = CosineSimilarity(0)
scores: Union[List[PoolingRequestOutput]] = []
scores: Union[list[PoolingRequestOutput]] = []
for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
......@@ -38,8 +38,8 @@ def _cosine_similarity(
def _validate_score_input_lens(
texts_1: Union[List[str], List[dict]],
texts_2: Union[List[str], List[dict]],
texts_1: Union[list[str], list[dict]],
texts_2: Union[list[str], list[dict]],
):
if len(texts_1) > 1 and len(texts_1) != len(texts_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
......
......@@ -2,7 +2,7 @@
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
......@@ -67,12 +67,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_PLUGINS: Optional[list[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = False
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
......@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# begin-env-vars-definition
environment_variables: Dict[str, Callable[[], Any]] = {
environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
......
......@@ -4,7 +4,7 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.distributed as dist
......@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
attn_layers: Dict[str, Any]
attn_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
num_tokens_across_dp: Optional[
List[int]] = None # set dynamically for each forward pass
list[int]] = None # set dynamically for each forward pass
_forward_context: Optional[ForwardContext] = None
......
......@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None:
custom_config = json.loads(file.read())
if not isinstance(custom_config, dict):
raise ValueError("Invalid logging config. Expected Dict, got %s.",
raise ValueError("Invalid logging config. Expected dict, got %s.",
type(custom_config).__name__)
logging_config = custom_config
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Tuple, Union
from typing import Callable, Union
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
Callable[[list[int], list[int], torch.Tensor],
torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
......@@ -17,9 +17,9 @@ to sample from."""
def get_bad_words_logits_processors(
bad_words: List[str],
tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
bad_words_ids: List[List[int]] = list()
bad_words: list[str],
tokenizer: AnyTokenizer) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list()
for bad_word in bad_words:
# To prohibit words both at the beginning
......@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT = float("-inf")
_NEUTRAL_LOGIT = 0.0
def __init__(self, bad_words_ids: List[List[int]]):
def __init__(self, bad_words_ids: list[list[int]]):
self.bad_words_ids = bad_words_ids
self.word_bias: torch.FloatTensor = None
def __call__(
self,
past_tokens_ids: Union[List[int], Tuple[int]],
past_tokens_ids: Union[list[int], tuple[int]],
logits: torch.FloatTensor,
) -> torch.Tensor:
if self.word_bias is None:
......
# SPDX-License-Identifier: Apache-2.0
import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Dict, Generic, List, MutableSequence, Optional
from typing import Sequence as GenericSequence
from typing import Union
from typing import Generic, Optional, Union
import torch
from typing_extensions import TypeVar, deprecated
......@@ -109,14 +109,14 @@ class RequestOutput:
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
prompt_token_ids: Optional[list[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
outputs: list[CompletionOutput],
finished: bool,
metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
encoder_prompt_token_ids: Optional[list[int]] = None,
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
......@@ -139,9 +139,9 @@ class RequestOutput:
cls,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
prompt_token_ids: Optional[list[int]],
text: str,
token_ids: List[int],
token_ids: list[int],
logprobs: Optional[SampleLogprobs],
prompt_logprobs: Optional[PromptLogprobs],
cumulative_logprob: Optional[float],
......@@ -189,7 +189,7 @@ class RequestOutput:
@classmethod
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: Dict[str, SequenceGroupBase]
seq_id_to_seq_group: dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]:
finished = seq_group.is_finished()
......@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]):
Args:
request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
prompt_token_ids (list[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the pooling is completed.
"""
def __init__(self, request_id: str, outputs: _O,
prompt_token_ids: List[int], finished: bool):
prompt_token_ids: list[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
self.finished = finished
......@@ -407,7 +407,7 @@ class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: Dict[str, SequenceGroupBase],
seq_id_to_seq_group: dict[str, SequenceGroupBase],
use_cache: bool = False):
if seq_group.pooled_data is not None:
return PoolingRequestOutput.from_seq_group(seq_group)
......
......@@ -4,11 +4,10 @@ import copy
from dataclasses import dataclass
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Union
from typing import Annotated, Any, Optional, Union
import msgspec
from pydantic import BaseModel
from typing_extensions import Annotated
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
......@@ -29,9 +28,9 @@ class SamplingType(IntEnum):
@dataclass
class GuidedDecodingParams:
"""One of these fields will be used to build a logit processor."""
json: Optional[Union[str, Dict]] = None
json: Optional[Union[str, dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
choice: Optional[list[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
"""These are other options that can be set"""
......@@ -40,9 +39,9 @@ class GuidedDecodingParams:
@staticmethod
def from_optional(
json: Optional[Union[Dict, BaseModel, str]] = None,
json: Optional[Union[dict, BaseModel, str]] = None,
regex: Optional[str] = None,
choice: Optional[List[str]] = None,
choice: Optional[list[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
......@@ -72,7 +71,7 @@ class GuidedDecodingParams:
"""
return (self.backend or "").split(":")[0]
def backend_options(self) -> List[str]:
def backend_options(self) -> list[str]:
"""Return the backend options as a list of strings."""
if not self.backend or ":" not in self.backend:
return []
......@@ -144,12 +143,12 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
stop: List of strings that stop the generation when they are generated.
stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
stop_token_ids: list of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
bad_words: List of words that are not allowed to be generated.
bad_words: list of words that are not allowed to be generated.
More precisely, only the last token of a corresponding
token sequence is not allowed when the next generated token
can complete the sequence.
......@@ -172,7 +171,7 @@ class SamplingParams(
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
logits_processors: list of functions that modify logits based on
previously generated tokens, and optionally prompt tokens as
a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k
......@@ -198,9 +197,9 @@ class SamplingParams(
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
bad_words: Optional[List[str]] = None
stop: Optional[Union[str, list[str]]] = None
stop_token_ids: Optional[list[int]] = None
bad_words: Optional[list[str]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
......@@ -212,8 +211,8 @@ class SamplingParams(
detokenize: bool = True
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Optional[List[LogitsProcessor]] type. We use Any here because
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
# Optional[list[LogitsProcessor]] type. We use Any here because
# Optional[list[LogitsProcessor]] type is not supported by msgspec.
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
......@@ -222,12 +221,12 @@ class SamplingParams(
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None
logit_bias: Optional[dict[int, float]] = None
allowed_token_ids: Optional[list[int]] = None
@staticmethod
def from_optional(
......@@ -241,9 +240,9 @@ class SamplingParams(
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
bad_words: Optional[List[str]] = None,
stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[list[str]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
......@@ -253,13 +252,13 @@ class SamplingParams(
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
logits_processors: Optional[list[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
......@@ -435,7 +434,7 @@ class SamplingParams(
def update_from_generation_config(
self,
generation_config: Dict[str, Any],
generation_config: dict[str, Any],
model_eos_token_id: Optional[int] = None) -> None:
"""Update if there are non-default values from generation_config"""
......@@ -468,7 +467,7 @@ class SamplingParams(
return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> Set[int]:
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids
def clone(self) -> "SamplingParams":
......
......@@ -5,11 +5,11 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
from typing import Any, Callable, Optional, Union
import msgspec
import torch
......@@ -50,9 +50,9 @@ class Logprob:
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
SampleLogprobs = list[dict[int, Logprob]]
class SequenceStatus(enum.IntEnum):
......@@ -129,7 +129,7 @@ class SequenceDataDelta(
omit_defaults=True): # type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData.
new_output_token_ids: List[int]
new_output_token_ids: list[int]
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob: float
# Overwriting existing `num_computed_tokens`.
......@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
# NOTE: we cannot use Union[List, array] because msgspec cannot support
# NOTE: we cannot use Union[list, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
......@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
_prompt_token_ids_tuple: tuple[int,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
# The number of tokens with prefix cache hit.
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
_new_appended_tokens: list[int] = msgspec.field(default_factory=list)
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
*token_counts: tuple[int, int]) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
......@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens()
def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids)
@property
......@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
return self._cumulative_logprob
@property
def prompt_token_ids(self) -> Tuple[int, ...]:
def prompt_token_ids(self) -> tuple[int, ...]:
return self._prompt_token_ids_tuple
@prompt_token_ids.setter
......@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids
@property
def output_token_ids(self) -> Tuple[int, ...]:
def output_token_ids(self) -> tuple[int, ...]:
return tuple(self._output_token_ids)
@output_token_ids.setter
......@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
def get_output_len(self) -> int:
return len(self._output_token_ids)
def get_token_ids(self) -> List[int]:
def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids
def get_prefix_token_ids(
self, num_tokens: int
) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length = self.get_prompt_len()
if num_tokens > prompt_length:
......@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]:
def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.prompt_token_ids
def get_output_token_ids(self) -> Tuple[int, ...]:
def get_output_token_ids(self) -> tuple[int, ...]:
return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta:
......@@ -432,7 +432,7 @@ class Sequence:
self.prefix_offset = 0
self.read_offset = 0
# Input + output tokens
self.tokens: Optional[List[str]] = None
self.tokens: Optional[list[str]] = None
@property
def n_blocks(self) -> int:
......@@ -443,7 +443,7 @@ class Sequence:
return self.inputs.prompt
@property
def prompt_token_ids(self) -> List[int]:
def prompt_token_ids(self) -> list[int]:
return self.inputs.prompt_token_ids
@property
......@@ -451,7 +451,7 @@ class Sequence:
return self.inputs.prompt_embeds
@property
def token_type_ids(self) -> List[int]:
def token_type_ids(self) -> list[int]:
return self.inputs.token_type_ids
@property
......@@ -463,7 +463,7 @@ class Sequence:
return self.inputs.multi_modal_placeholders
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
def mm_processor_kwargs(self) -> dict[str, Any]:
return self.inputs.mm_processor_kwargs
@property
......@@ -548,7 +548,7 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: Dict[int,
def append_token_id(self, token_id: int, logprobs: dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
......@@ -563,16 +563,16 @@ class Sequence:
def get_output_len(self) -> int:
return self.data.get_output_len()
def get_token_ids(self) -> List[int]:
def get_token_ids(self) -> list[int]:
return self.data.get_token_ids()
def get_prompt_token_ids(self) -> Tuple[int, ...]:
def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]:
def get_output_token_ids(self) -> tuple[int, ...]:
return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float:
......@@ -644,7 +644,7 @@ class SequenceGroup:
def __init__(
self,
request_id: str,
seqs: List[Sequence],
seqs: list[Sequence],
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
......@@ -686,7 +686,7 @@ class SequenceGroup:
return self.first_seq.prompt
@property
def prompt_token_ids(self) -> List[int]:
def prompt_token_ids(self) -> list[int]:
return self.first_seq.prompt_token_ids
@property
......@@ -698,7 +698,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None)
@property
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
def encoder_prompt_token_ids(self) -> Optional[list[int]]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# distinct from the decoder's.
......@@ -706,7 +706,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None)
@property
def token_type_ids(self) -> Optional[List[int]]:
def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids
@property
......@@ -726,7 +726,7 @@ class SequenceGroup:
return {}
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
def mm_processor_kwargs(self) -> dict[str, Any]:
if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None:
......@@ -823,7 +823,7 @@ class SequenceGroup:
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
) -> list[Sequence]:
if status is None:
return self.seqs
......@@ -838,7 +838,7 @@ class SequenceGroup:
def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq
def get_finished_seqs(self) -> List[Sequence]:
def get_finished_seqs(self) -> list[Sequence]:
if self.is_single_seq:
return self.seqs if self.first_seq.is_finished() else []
......@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
"""
seq_data_delta: Dict[int, SequenceDataDelta]
seq_data_delta: dict[int, SequenceDataDelta]
request_id: str
block_tables: Dict[int, List[int]]
block_tables: dict[int, list[int]]
is_prompt: bool
do_sample: bool = True
token_chunk_size: Optional[int] = None
computed_block_nums: Optional[List[int]] = None
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
......@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
seq_data: dict[int, SequenceData]
sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]]
block_tables: dict[int, list[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any because msgspec
# doesn't allow a union of 2 different dicts.
token_type_ids: Optional[List[int]] = None
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
......@@ -1042,7 +1042,7 @@ class SequenceOutput(
"""
parent_seq_id: int
output_token: int
logprobs: Dict[int, Logprob]
logprobs: dict[int, Logprob]
def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
......@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
array_like=True): # type: ignore[call-arg]
"""The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput]
samples: list[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
......@@ -1119,7 +1119,7 @@ class IntermediateTensors:
contains the hidden states and residuals for a request.
"""
tensors: Dict[str, torch.Tensor]
tensors: dict[str, torch.Tensor]
def __init__(self, tensors):
# manually define this function, so that
......@@ -1155,7 +1155,7 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs: List[PoolingSequenceGroupOutput]
outputs: list[PoolingSequenceGroupOutput]
# Indexing a PoolerOutput returns the idx-th entry of its `outputs` list.
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx]
......@@ -1172,7 +1172,7 @@ class PoolerOutput(
def get_all_seq_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
......@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
def get_all_seq_ids_and_request_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
seq_ids: List[int] = []
request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
seq_ids: list[int] = []
request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
for sg in seq_group_metadata_list:
for seq_id in sg.seq_data:
seq_ids.append(seq_id)
......@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it is used for last accepted tokens.
hidden_states: torch.Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)
_seq_ids: list[int] = msgspec.field(default_factory=list)
def __post_init__(self):
if self.seq_group_metadata_list is not None:
......@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property
def seq_ids(self) -> list[int]:
    """Return the cached sequence ids stored in ``self._seq_ids``."""
    return self._seq_ids
def update(self,
hidden_states: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_group_metadata_list: list[SequenceGroupMetadata],
second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for
decode steps"""
......@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
......@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int,
blocks_to_swap_in: list[tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int,
blocks_to_swap_out: list[tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
......@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
# The step index for spec model input.
spec_step_idx: Optional[int] = None
# Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
finished_requests_ids: list[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
......@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
return state.current_step
def clone(
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
......@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
assembled_seq_group: Optional[SequenceGroup] = None
# seq id to a unique index inside this group
seq_id_to_index: Dict[str, int] = field(default_factory=dict)
seq_id_to_index: dict[str, int] = field(default_factory=dict)
# seq ids to be finished
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)
# seq id to finished sequences
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)
streaming: bool = False
......
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Mapping, Optional
from collections.abc import Mapping
from typing import Optional
from vllm.logger import init_logger
from vllm.utils import run_once
......
......@@ -28,12 +28,12 @@ import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Hashable, Iterable, Mapping
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, Mapping)
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, TypeVar, Union)
from uuid import uuid4
import cloudpickle
......@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None],
async def merge_async_iterators(
*iterators: AsyncGenerator[T,
None], ) -> AsyncGenerator[Tuple[int, T], None]:
None], ) -> AsyncGenerator[tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
......@@ -433,7 +433,7 @@ async def merge_async_iterators(
async def collect_from_async_generator(
iterator: AsyncGenerator[T, None]) -> List[T]:
iterator: AsyncGenerator[T, None]) -> list[T]:
"""Collect all items from an async generator into a list."""
items = []
async for item in iterator:
......@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]:
return None
def update_environment_variables(envs: Dict[str, str]):
def update_environment_variables(envs: dict[str, str]):
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
......@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v
def chunk_list(lst: list[T], chunk_size: int):
    """Yield successive chunk_size chunks from lst.

    Chunks are consecutive slices of ``lst``; the final chunk is shorter
    than ``chunk_size`` when ``len(lst)`` is not a multiple of it. An empty
    ``lst`` yields nothing.
    """
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i + chunk_size]
......@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash(
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
......@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash(
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5
key_caches: List[torch.Tensor] = []
value_caches: List[torch.Tensor] = []
key_caches: list[torch.Tensor] = []
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
......@@ -679,7 +679,7 @@ def create_kv_caches_with_random(
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
......@@ -693,7 +693,7 @@ def create_kv_caches_with_random(
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: List[torch.Tensor] = []
key_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
......@@ -708,7 +708,7 @@ def create_kv_caches_with_random(
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: List[torch.Tensor] = []
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
......@@ -754,7 +754,7 @@ class DeviceMemoryProfiler:
def make_ndarray_with_pad(
x: List[List[T]],
x: list[list[T]],
pad: T,
dtype: npt.DTypeLike,
*,
......@@ -779,7 +779,7 @@ def make_ndarray_with_pad(
def make_tensor_with_pad(
x: List[List[T]],
x: list[list[T]],
pad: T,
dtype: torch.dtype,
*,
......@@ -831,7 +831,7 @@ def is_list_of(
typ: Union[type[T], tuple[type[T], ...]],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
) -> TypeIs[list[T]]:
if not isinstance(value, list):
return False
......@@ -843,8 +843,8 @@ def is_list_of(
assert_never(check)
# Recursive alias: a tree is a str-keyed dict of trees, a list of trees,
# a tuple of trees, or a leaf of type T.
JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"],
                 tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
......@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
return func(value)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
    """Flatten a list of lists to a single list.

    Items keep their original order: every item of the first sublist, then
    every item of the second, and so on. Only one level of nesting is removed.
    """
    return [item for sublist in lists for item in sublist]
......@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return value
def _pull_args_from_config(self, args: List[str]) -> List[str]:
def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
......@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return args
def _load_config_file(self, file_path: str) -> List[str]:
def _load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
......@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
%s supplied", extension)
# only expecting a flat dictionary of atomic types
processed_args: List[str] = []
processed_args: list[str] = []
config: Dict[str, Union[int, str]] = {}
config: dict[str, Union[int, str]] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
......@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs(
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those which are not explicit keywords to the given callable (if one is
given; otherwise no filtering is done), then merges the kwarg dicts,
......@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides(
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that cannot be kwarg
......@@ -1531,9 +1531,9 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]):
def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}
self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T:
if key not in self._dict:
......@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return len(self._factory)
class ClassRegistry(UserDict[Type[T], _V]):
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: Type[T]) -> _V:
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
......@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def weak_ref_tensors(
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
......@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: List[str],
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
......@@ -2177,8 +2177,8 @@ def get_mp_context():
def bind_kv_cache(
ctx: Dict[str, Any],
kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index]
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
......@@ -2210,8 +2210,8 @@ def bind_kv_cache(
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
kwargs: Dict[str, Any]) -> Any:
def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],
kwargs: dict[str, Any]) -> Any:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr.
......@@ -2263,7 +2263,7 @@ def import_pynvml():
return pynvml
def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
"""
A replacement for `abc.ABC`.
When we use `abc.ABC`, subclasses will fail to instantiate
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
......@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
    """Return the list of head sizes this backend reports as supported."""
    return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
......@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend):
return "FLASH_ATTN_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
    """Return the attention implementation class, ``FlashAttentionImpl``."""
    return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the attention metadata class, ``FlashAttentionMetadata``."""
    return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
    """Return the metadata builder class, ``FlashAttentionMetadataBuilder``."""
    return FlashAttentionMetadataBuilder
@staticmethod
......@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
......@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
......@@ -381,7 +381,7 @@ def cascade_attention(
max_kv_len: int,
softmax_scale: float,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Tuple[int, int],
sliding_window: tuple[int, int],
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
......
......@@ -195,8 +195,7 @@ return curr_o @ W_O
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch
from compressed_tensors.quantization import QuantizationStrategy
......@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend):
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the attention metadata class, ``MLACommonMetadata``."""
    return MLACommonMetadata
@staticmethod
def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
    """Return the metadata builder class, ``MLACommonMetadataBuilder``."""
    return MLACommonMetadataBuilder
@staticmethod
......@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend):
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def get_supported_head_sizes() -> list[int]:
    """Return the list of head sizes this backend reports as supported."""
    return [576]
@staticmethod
......@@ -317,8 +316,8 @@ class MLACommonMetadata:
has_context: bool = False
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
context_chunk_starts: Optional[torch.Tensor] = None
context_chunk_seq_tot: Optional[List[int]] = None
context_chunk_max_seq_lens: Optional[List[int]] = None
context_chunk_seq_tot: Optional[list[int]] = None
context_chunk_max_seq_lens: Optional[list[int]] = None
chunked_prefill_workspace: Optional[torch.Tensor] = None
def __post_init__(self):
......@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
......@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
tuple[tuple[int, int], tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
import torch
......@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend):
return "FLASHMLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
    """Return the attention metadata class, ``FlashMLAMetadata``."""
    return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
    """Return the metadata builder class, ``FlashMLAMetadataBuilder``."""
    return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
    """Return the attention implementation class, ``FlashMLAImpl``."""
    return FlashMLAImpl
@dataclass
class FlashMLAMetadata(MLACommonMetadata):
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor,
torch.Tensor]] = None
decode_num_splits: Optional[torch.Tensor] = None
......@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment