Unverified commit cf069aa8, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import json import json
import re import re
from typing import Dict, List, Sequence, Union from collections.abc import Sequence
from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser): ...@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser):
self.model_tokenizer = self.model_tokenizer.tokenizer self.model_tokenizer = self.model_tokenizer.tokenizer
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>" self.tool_call_start_token: str = "<tool_call>"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
from typing import Dict, Sequence, Union from collections.abc import Sequence
from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser): ...@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser):
# tool calls are generated in an object in internlm2 # tool calls are generated in an object in internlm2
# it does not support parallel tool calls # it does not support parallel tool calls
try: try:
tool_call_arr: Dict = partial_json_parser.loads( tool_call_arr: dict = partial_json_parser.loads(
parsable_arr, flags) parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON: except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet') logger.debug('not enough tokens to parse into JSON yet')
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import json import json
import re import re
from typing import Dict, List, Sequence, Union from collections.abc import Sequence
from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser): ...@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser):
) )
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool_calls>" self.tool_calls_start_token: str = "<tool_calls>"
...@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser): ...@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON # tool calls are generated in an array, so do partial JSON
# parsing on the entire array # parsing on the entire array
try: try:
tool_call_arr: List[Dict] = partial_json_parser.loads( tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags) parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON: except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet') logger.debug('not enough tokens to parse into JSON yet')
...@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser): ...@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser):
# select as the current tool call the one we're on the state at # select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {} if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g. # case -- if no tokens have been streamed for the tool, e.g.
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
import json import json
import re import re
from collections.abc import Sequence
from json import JSONDecoder from json import JSONDecoder
from typing import Dict, List, Sequence, Union from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser): ...@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in # initialize properties used for state when parsing tool calls in
# streaming mode # streaming mode
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>" self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token, self.bot_token_id = tokenizer.encode(self.bot_token,
...@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser): ...@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser):
start_idx += end_idx + len('; ') start_idx += end_idx + len('; ')
function_call_arr.append(obj) function_call_arr.append(obj)
tool_calls: List[ToolCall] = [ tool_calls: list[ToolCall] = [
ToolCall( ToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
...@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser): ...@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser):
return None return None
# select as the current tool call the one we're on the state at # select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {} if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g. # case -- if no tokens have been streamed for the tool, e.g.
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import json import json
import re import re
from collections.abc import Sequence
from random import choices from random import choices
from string import ascii_letters, digits from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser): ...@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in # initialize properties used for state when parsing tool calls in
# streaming mode # streaming mode
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]" self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token) self.bot_token_id = self.vocab.get(self.bot_token)
...@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser): ...@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser):
function_call_arr = json.loads(raw_tool_call) function_call_arr = json.loads(raw_tool_call)
# Tool Call # Tool Call
tool_calls: List[MistralToolCall] = [ tool_calls: list[MistralToolCall] = [
MistralToolCall( MistralToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
...@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser): ...@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON # tool calls are generated in an array, so do partial JSON
# parsing on the entire array # parsing on the entire array
try: try:
tool_call_arr: List[Dict] = partial_json_parser.loads( tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags) parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON: except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet') logger.debug('not enough tokens to parse into JSON yet')
...@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser): ...@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser):
# select as the current tool call the one we're on the state at # select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {} if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g. # case -- if no tokens have been streamed for the tool, e.g.
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import ast import ast
import json import json
import re import re
from typing import Any, Sequence, Tuple, Union from collections.abc import Sequence
from typing import Any, Union
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: ...@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments=json.dumps(arguments))) arguments=json.dumps(arguments)))
def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
bracket_stack = [] bracket_stack = []
for index, char in enumerate(text): for index, char in enumerate(text):
if char in {"[", "(", "{"}: if char in {"[", "(", "{"}:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import json import json
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
from typing import Any, List, Tuple from typing import Any
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
...@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str: ...@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
return diff return diff
def find_all_indices(string: str, substring: str) -> List[int]: def find_all_indices(string: str, substring: str) -> list[int]:
""" """
Find all (starting) indices of a substring in a given string. Useful for Find all (starting) indices of a substring in a given string. Useful for
tool call extraction tool call extraction
...@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]: ...@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]:
# partial_json_parser doesn't support extra data and # partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON # JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try: try:
return (partial_json_parser.loads(input_str, flags), len(input_str)) return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e: except JSONDecodeError as e:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Union from typing import Union
from torch.nn import CosineSimilarity from torch.nn import CosineSimilarity
...@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer, ...@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
def _cosine_similarity( def _cosine_similarity(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embed_1: List[PoolingRequestOutput], embed_1: list[PoolingRequestOutput],
embed_2: List[PoolingRequestOutput], embed_2: list[PoolingRequestOutput],
) -> List[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
scorer = CosineSimilarity(0) scorer = CosineSimilarity(0)
scores: Union[List[PoolingRequestOutput]] = [] scores: Union[list[PoolingRequestOutput]] = []
for emb_1, emb_2 in zip(embed_1, embed_2): for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
...@@ -38,8 +38,8 @@ def _cosine_similarity( ...@@ -38,8 +38,8 @@ def _cosine_similarity(
def _validate_score_input_lens( def _validate_score_input_lens(
texts_1: Union[List[str], List[dict]], texts_1: Union[list[str], list[dict]],
texts_2: Union[List[str], List[dict]], texts_2: Union[list[str], list[dict]],
): ):
if len(texts_1) > 1 and len(texts_1) != len(texts_2): if len(texts_1) > 1 and len(texts_1) != len(texts_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N") raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import os import os
import tempfile import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
...@@ -67,12 +67,12 @@ if TYPE_CHECKING: ...@@ -67,12 +67,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[list[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: List[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = False VLLM_USE_V1: bool = False
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
...@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: ...@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# begin-env-vars-definition # begin-env-vars-definition
environment_variables: Dict[str, Callable[[], Any]] = { environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ================== # ================== Installation Time Env Vars ==================
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list) ...@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass @dataclass
class ForwardContext: class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context # copy from vllm_config.compilation_config.static_forward_context
attn_layers: Dict[str, Any] attn_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context # TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache # TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass virtual_engine: int # set dynamically for each forward pass
num_tokens_across_dp: Optional[ num_tokens_across_dp: Optional[
List[int]] = None # set dynamically for each forward pass list[int]] = None # set dynamically for each forward pass
_forward_context: Optional[ForwardContext] = None _forward_context: Optional[ForwardContext] = None
......
...@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None: ...@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None:
custom_config = json.loads(file.read()) custom_config = json.loads(file.read())
if not isinstance(custom_config, dict): if not isinstance(custom_config, dict):
raise ValueError("Invalid logging config. Expected Dict, got %s.", raise ValueError("Invalid logging config. Expected dict, got %s.",
type(custom_config).__name__) type(custom_config).__name__)
logging_config = custom_config logging_config = custom_config
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Tuple, Union from typing import Callable, Union
import torch import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor], Callable[[list[int], list[int], torch.Tensor],
torch.Tensor]] torch.Tensor]]
"""LogitsProcessor is a function that takes a list """LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor of previously generated tokens, the logits tensor
...@@ -17,9 +17,9 @@ to sample from.""" ...@@ -17,9 +17,9 @@ to sample from."""
def get_bad_words_logits_processors( def get_bad_words_logits_processors(
bad_words: List[str], bad_words: list[str],
tokenizer: AnyTokenizer) -> List[LogitsProcessor]: tokenizer: AnyTokenizer) -> list[LogitsProcessor]:
bad_words_ids: List[List[int]] = list() bad_words_ids: list[list[int]] = list()
for bad_word in bad_words: for bad_word in bad_words:
# To prohibit words both at the beginning # To prohibit words both at the beginning
...@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor: ...@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT = float("-inf") _SMALLEST_LOGIT = float("-inf")
_NEUTRAL_LOGIT = 0.0 _NEUTRAL_LOGIT = 0.0
def __init__(self, bad_words_ids: List[List[int]]): def __init__(self, bad_words_ids: list[list[int]]):
self.bad_words_ids = bad_words_ids self.bad_words_ids = bad_words_ids
self.word_bias: torch.FloatTensor = None self.word_bias: torch.FloatTensor = None
def __call__( def __call__(
self, self,
past_tokens_ids: Union[List[int], Tuple[int]], past_tokens_ids: Union[list[int], tuple[int]],
logits: torch.FloatTensor, logits: torch.FloatTensor,
) -> torch.Tensor: ) -> torch.Tensor:
if self.word_bias is None: if self.word_bias is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import time import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Generic, List, MutableSequence, Optional from typing import Generic, Optional, Union
from typing import Sequence as GenericSequence
from typing import Union
import torch import torch
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
...@@ -109,14 +109,14 @@ class RequestOutput: ...@@ -109,14 +109,14 @@ class RequestOutput:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[list[int]],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: list[CompletionOutput],
finished: bool, finished: bool,
metrics: Optional[RequestMetrics] = None, metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None, encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None, encoder_prompt_token_ids: Optional[list[int]] = None,
num_cached_tokens: Optional[int] = None, num_cached_tokens: Optional[int] = None,
*, *,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
...@@ -139,9 +139,9 @@ class RequestOutput: ...@@ -139,9 +139,9 @@ class RequestOutput:
cls, cls,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[list[int]],
text: str, text: str,
token_ids: List[int], token_ids: list[int],
logprobs: Optional[SampleLogprobs], logprobs: Optional[SampleLogprobs],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
cumulative_logprob: Optional[float], cumulative_logprob: Optional[float],
...@@ -189,7 +189,7 @@ class RequestOutput: ...@@ -189,7 +189,7 @@ class RequestOutput:
@classmethod @classmethod
def from_seq_group( def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool, cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: Dict[str, SequenceGroupBase] seq_id_to_seq_group: dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]: ) -> Optional["RequestOutput"]:
finished = seq_group.is_finished() finished = seq_group.is_finished()
...@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]): ...@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]):
Args: Args:
request_id (str): A unique identifier for the pooling request. request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input. outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt. prompt_token_ids (list[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the pooling is completed. finished (bool): A flag indicating whether the pooling is completed.
""" """
def __init__(self, request_id: str, outputs: _O, def __init__(self, request_id: str, outputs: _O,
prompt_token_ids: List[int], finished: bool): prompt_token_ids: list[int], finished: bool):
self.request_id = request_id self.request_id = request_id
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.finished = finished self.finished = finished
...@@ -407,7 +407,7 @@ class RequestOutputFactory: ...@@ -407,7 +407,7 @@ class RequestOutputFactory:
@staticmethod @staticmethod
def create(seq_group: SequenceGroup, def create(seq_group: SequenceGroup,
seq_id_to_seq_group: Dict[str, SequenceGroupBase], seq_id_to_seq_group: dict[str, SequenceGroupBase],
use_cache: bool = False): use_cache: bool = False):
if seq_group.pooled_data is not None: if seq_group.pooled_data is not None:
return PoolingRequestOutput.from_seq_group(seq_group) return PoolingRequestOutput.from_seq_group(seq_group)
......
...@@ -4,11 +4,10 @@ import copy ...@@ -4,11 +4,10 @@ import copy
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Union from typing import Annotated, Any, Optional, Union
import msgspec import msgspec
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
...@@ -29,9 +28,9 @@ class SamplingType(IntEnum): ...@@ -29,9 +28,9 @@ class SamplingType(IntEnum):
@dataclass @dataclass
class GuidedDecodingParams: class GuidedDecodingParams:
"""One of these fields will be used to build a logit processor.""" """One of these fields will be used to build a logit processor."""
json: Optional[Union[str, Dict]] = None json: Optional[Union[str, dict]] = None
regex: Optional[str] = None regex: Optional[str] = None
choice: Optional[List[str]] = None choice: Optional[list[str]] = None
grammar: Optional[str] = None grammar: Optional[str] = None
json_object: Optional[bool] = None json_object: Optional[bool] = None
"""These are other options that can be set""" """These are other options that can be set"""
...@@ -40,9 +39,9 @@ class GuidedDecodingParams: ...@@ -40,9 +39,9 @@ class GuidedDecodingParams:
@staticmethod @staticmethod
def from_optional( def from_optional(
json: Optional[Union[Dict, BaseModel, str]] = None, json: Optional[Union[dict, BaseModel, str]] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
choice: Optional[List[str]] = None, choice: Optional[list[str]] = None,
grammar: Optional[str] = None, grammar: Optional[str] = None,
json_object: Optional[bool] = None, json_object: Optional[bool] = None,
backend: Optional[str] = None, backend: Optional[str] = None,
...@@ -72,7 +71,7 @@ class GuidedDecodingParams: ...@@ -72,7 +71,7 @@ class GuidedDecodingParams:
""" """
return (self.backend or "").split(":")[0] return (self.backend or "").split(":")[0]
def backend_options(self) -> List[str]: def backend_options(self) -> list[str]:
"""Return the backend options as a list of strings.""" """Return the backend options as a list of strings."""
if not self.backend or ":" not in self.backend: if not self.backend or ":" not in self.backend:
return [] return []
...@@ -144,12 +143,12 @@ class SamplingParams( ...@@ -144,12 +143,12 @@ class SamplingParams(
considered, relative to the probability of the most likely token. considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this. Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation. seed: Random seed to use for the generation.
stop: List of strings that stop the generation when they are generated. stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings. The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are stop_token_ids: list of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens. the stop tokens are special tokens.
bad_words: List of words that are not allowed to be generated. bad_words: list of words that are not allowed to be generated.
More precisely, only the last token of a corresponding More precisely, only the last token of a corresponding
token sequence is not allowed when the next generated token token sequence is not allowed when the next generated token
can complete the sequence. can complete the sequence.
...@@ -172,7 +171,7 @@ class SamplingParams( ...@@ -172,7 +171,7 @@ class SamplingParams(
skip_special_tokens: Whether to skip special tokens in the output. skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True. tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on logits_processors: list of functions that modify logits based on
previously generated tokens, and optionally prompt tokens as previously generated tokens, and optionally prompt tokens as
a first argument. a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k truncate_prompt_tokens: If set to an integer k, will use only the last k
...@@ -198,9 +197,9 @@ class SamplingParams( ...@@ -198,9 +197,9 @@ class SamplingParams(
top_k: int = -1 top_k: int = -1
min_p: float = 0.0 min_p: float = 0.0
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, list[str]]] = None
stop_token_ids: Optional[List[int]] = None stop_token_ids: Optional[list[int]] = None
bad_words: Optional[List[str]] = None bad_words: Optional[list[str]] = None
ignore_eos: bool = False ignore_eos: bool = False
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
min_tokens: int = 0 min_tokens: int = 0
...@@ -212,8 +211,8 @@ class SamplingParams( ...@@ -212,8 +211,8 @@ class SamplingParams(
detokenize: bool = True detokenize: bool = True
skip_special_tokens: bool = True skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True spaces_between_special_tokens: bool = True
# Optional[List[LogitsProcessor]] type. We use Any here because # Optional[list[LogitsProcessor]] type. We use Any here because
# Optional[List[LogitsProcessor]] type is not supported by msgspec. # Optional[list[LogitsProcessor]] type is not supported by msgspec.
logits_processors: Optional[Any] = None logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
...@@ -222,12 +221,12 @@ class SamplingParams( ...@@ -222,12 +221,12 @@ class SamplingParams(
# The below fields are not supposed to be used as an input. # The below fields are not supposed to be used as an input.
# They are set in post_init. # They are set in post_init.
output_text_buffer_length: int = 0 output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors # Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None logit_bias: Optional[dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None allowed_token_ids: Optional[list[int]] = None
@staticmethod @staticmethod
def from_optional( def from_optional(
...@@ -241,9 +240,9 @@ class SamplingParams( ...@@ -241,9 +240,9 @@ class SamplingParams(
top_k: int = -1, top_k: int = -1,
min_p: float = 0.0, min_p: float = 0.0,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[List[str]] = None, bad_words: Optional[list[str]] = None,
include_stop_str_in_output: bool = False, include_stop_str_in_output: bool = False,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: Optional[int] = 16, max_tokens: Optional[int] = 16,
...@@ -253,13 +252,13 @@ class SamplingParams( ...@@ -253,13 +252,13 @@ class SamplingParams(
detokenize: bool = True, detokenize: bool = True,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[list[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None, msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None, guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None, allowed_token_ids: Optional[list[int]] = None,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None: if logit_bias is not None:
# Convert token_id to integer # Convert token_id to integer
...@@ -435,7 +434,7 @@ class SamplingParams( ...@@ -435,7 +434,7 @@ class SamplingParams(
def update_from_generation_config( def update_from_generation_config(
self, self,
generation_config: Dict[str, Any], generation_config: dict[str, Any],
model_eos_token_id: Optional[int] = None) -> None: model_eos_token_id: Optional[int] = None) -> None:
"""Update if there are non-default values from generation_config""" """Update if there are non-default values from generation_config"""
...@@ -468,7 +467,7 @@ class SamplingParams( ...@@ -468,7 +467,7 @@ class SamplingParams(
return SamplingType.RANDOM return SamplingType.RANDOM
@property @property
def all_stop_token_ids(self) -> Set[int]: def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids return self._all_stop_token_ids
def clone(self) -> "SamplingParams": def clone(self) -> "SamplingParams":
......
...@@ -5,11 +5,11 @@ import enum ...@@ -5,11 +5,11 @@ import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import reduce from functools import reduce
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Any, Callable, Optional, Union
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
import msgspec import msgspec
import torch import torch
...@@ -50,9 +50,9 @@ class Logprob: ...@@ -50,9 +50,9 @@ class Logprob:
# {token_id -> logprob} per each sequence group. None if the corresponding # {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob. # sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]] PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group. # {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]] SampleLogprobs = list[dict[int, Logprob]]
class SequenceStatus(enum.IntEnum): class SequenceStatus(enum.IntEnum):
...@@ -129,7 +129,7 @@ class SequenceDataDelta( ...@@ -129,7 +129,7 @@ class SequenceDataDelta(
omit_defaults=True): # type: ignore[call-arg] omit_defaults=True): # type: ignore[call-arg]
"""Delta SequenceData to send to workers per step.""" """Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData. # A new token to be appended to existing SequenceData.
new_output_token_ids: List[int] new_output_token_ids: list[int]
# Overwriting existing `cumulative_logprob` # Overwriting existing `cumulative_logprob`
new_cumulative_logprob: float new_cumulative_logprob: float
# Overwriting existing `num_computed_tokens`. # Overwriting existing `num_computed_tokens`.
...@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct, ...@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
output_token_ids: The token IDs of the output. output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output. cumulative_logprob: The cumulative log probability of the output.
""" """
# NOTE: we cannot use Union[List, array] because msgspec cannot support # NOTE: we cannot use Union[list, array] because msgspec cannot support
# union of 2 list types. # union of 2 list types.
_prompt_token_ids: array _prompt_token_ids: array
_output_token_ids: array = msgspec.field( _output_token_ids: array = msgspec.field(
...@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct, ...@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
### The below fields should not be passed as an argument ### ### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0 _cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int, _prompt_token_ids_tuple: tuple[int,
...] = msgspec.field(default_factory=tuple) ...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0 _num_computed_tokens: int = 0
# The number of tokens with prefix cache hit. # The number of tokens with prefix cache hit.
_num_cached_tokens: int = 0 _num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL _stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list) _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
# It is used to get delta input. It is reset when `get_delta_and_reset` # It is used to get delta input. It is reset when `get_delta_and_reset`
# is called. # is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list) _new_appended_tokens: list[int] = msgspec.field(default_factory=list)
# It is used to compute mrope_position_ids. # It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
@staticmethod @staticmethod
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData": *token_counts: tuple[int, int]) -> "SequenceData":
""" """
Construct a :class:`SequenceData` instance by concatenating Construct a :class:`SequenceData` instance by concatenating
prompt token sequences. prompt token sequences.
...@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct, ...@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l" assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l"
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids) self._prompt_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
def _update_cached_all_tokens(self): def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array) assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array) assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids) self._output_token_ids)
@property @property
...@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct, ...@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
return self._cumulative_logprob return self._cumulative_logprob
@property @property
def prompt_token_ids(self) -> Tuple[int, ...]: def prompt_token_ids(self) -> tuple[int, ...]:
return self._prompt_token_ids_tuple return self._prompt_token_ids_tuple
@prompt_token_ids.setter @prompt_token_ids.setter
...@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct, ...@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids return self._prompt_token_ids
@property @property
def output_token_ids(self) -> Tuple[int, ...]: def output_token_ids(self) -> tuple[int, ...]:
return tuple(self._output_token_ids) return tuple(self._output_token_ids)
@output_token_ids.setter @output_token_ids.setter
...@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct, ...@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self._output_token_ids) return len(self._output_token_ids)
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids return self._cached_all_token_ids
def get_prefix_token_ids( def get_prefix_token_ids(
self, num_tokens: int self, num_tokens: int
) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable""" """Get prefix tokens, and make the return value hashable"""
prompt_length = self.get_prompt_len() prompt_length = self.get_prompt_len()
if num_tokens > prompt_length: if num_tokens > prompt_length:
...@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct, ...@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids[-1] return self._prompt_token_ids[-1]
return self._output_token_ids[-1] return self._output_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.prompt_token_ids return self.prompt_token_ids
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> tuple[int, ...]:
return self.output_token_ids return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta: def get_delta_and_reset(self) -> SequenceDataDelta:
...@@ -432,7 +432,7 @@ class Sequence: ...@@ -432,7 +432,7 @@ class Sequence:
self.prefix_offset = 0 self.prefix_offset = 0
self.read_offset = 0 self.read_offset = 0
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[list[str]] = None
@property @property
def n_blocks(self) -> int: def n_blocks(self) -> int:
...@@ -443,7 +443,7 @@ class Sequence: ...@@ -443,7 +443,7 @@ class Sequence:
return self.inputs.prompt return self.inputs.prompt
@property @property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> list[int]:
return self.inputs.prompt_token_ids return self.inputs.prompt_token_ids
@property @property
...@@ -451,7 +451,7 @@ class Sequence: ...@@ -451,7 +451,7 @@ class Sequence:
return self.inputs.prompt_embeds return self.inputs.prompt_embeds
@property @property
def token_type_ids(self) -> List[int]: def token_type_ids(self) -> list[int]:
return self.inputs.token_type_ids return self.inputs.token_type_ids
@property @property
...@@ -463,7 +463,7 @@ class Sequence: ...@@ -463,7 +463,7 @@ class Sequence:
return self.inputs.multi_modal_placeholders return self.inputs.multi_modal_placeholders
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> dict[str, Any]:
return self.inputs.mm_processor_kwargs return self.inputs.mm_processor_kwargs
@property @property
...@@ -548,7 +548,7 @@ class Sequence: ...@@ -548,7 +548,7 @@ class Sequence:
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute() self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: Dict[int, def append_token_id(self, token_id: int, logprobs: dict[int,
Logprob]) -> None: Logprob]) -> None:
assert token_id in logprobs assert token_id in logprobs
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
...@@ -563,16 +563,16 @@ class Sequence: ...@@ -563,16 +563,16 @@ class Sequence:
def get_output_len(self) -> int: def get_output_len(self) -> int:
return self.data.get_output_len() return self.data.get_output_len()
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> list[int]:
return self.data.get_token_ids() return self.data.get_token_ids()
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.data.get_prompt_token_ids() return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int: def get_last_token_id(self) -> int:
return self.data.get_last_token_id() return self.data.get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> tuple[int, ...]:
return self.data.get_output_token_ids() return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
...@@ -644,7 +644,7 @@ class SequenceGroup: ...@@ -644,7 +644,7 @@ class SequenceGroup:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
seqs: List[Sequence], seqs: list[Sequence],
arrival_time: float, arrival_time: float,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -686,7 +686,7 @@ class SequenceGroup: ...@@ -686,7 +686,7 @@ class SequenceGroup:
return self.first_seq.prompt return self.first_seq.prompt
@property @property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> list[int]:
return self.first_seq.prompt_token_ids return self.first_seq.prompt_token_ids
@property @property
...@@ -698,7 +698,7 @@ class SequenceGroup: ...@@ -698,7 +698,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property @property
def encoder_prompt_token_ids(self) -> Optional[List[int]]: def encoder_prompt_token_ids(self) -> Optional[list[int]]:
# There are either 0 or 1 encoder sequences # There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are # If one is present, its prompt token ids are
# distinct from the decoder's. # distinct from the decoder's.
...@@ -706,7 +706,7 @@ class SequenceGroup: ...@@ -706,7 +706,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property @property
def token_type_ids(self) -> Optional[List[int]]: def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids return self.first_seq.token_type_ids
@property @property
...@@ -726,7 +726,7 @@ class SequenceGroup: ...@@ -726,7 +726,7 @@ class SequenceGroup:
return {} return {}
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> dict[str, Any]:
if self.first_seq.multi_modal_data: if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None: elif self.encoder_seq is not None:
...@@ -823,7 +823,7 @@ class SequenceGroup: ...@@ -823,7 +823,7 @@ class SequenceGroup:
def get_seqs( def get_seqs(
self, self,
status: Optional[SequenceStatus] = None, status: Optional[SequenceStatus] = None,
) -> List[Sequence]: ) -> list[Sequence]:
if status is None: if status is None:
return self.seqs return self.seqs
...@@ -838,7 +838,7 @@ class SequenceGroup: ...@@ -838,7 +838,7 @@ class SequenceGroup:
def get_encoder_seq(self) -> Optional[Sequence]: def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq return self.encoder_seq
def get_finished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> list[Sequence]:
if self.is_single_seq: if self.is_single_seq:
return self.seqs if self.first_seq.is_finished() else [] return self.seqs if self.first_seq.is_finished() else []
...@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta( ...@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
After sending the first SequenceGroupMetadata, vLLM scheduler After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size. only sends delta to reduce the data payload size.
""" """
seq_data_delta: Dict[int, SequenceDataDelta] seq_data_delta: dict[int, SequenceDataDelta]
request_id: str request_id: str
block_tables: Dict[int, List[int]] block_tables: dict[int, list[int]]
is_prompt: bool is_prompt: bool
do_sample: bool = True do_sample: bool = True
token_chunk_size: Optional[int] = None token_chunk_size: Optional[int] = None
computed_block_nums: Optional[List[int]] = None computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field( state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
...@@ -947,23 +947,23 @@ class SequenceGroupMetadata( ...@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
request_id: str request_id: str
is_prompt: bool is_prompt: bool
seq_data: Dict[int, SequenceData] seq_data: dict[int, SequenceData]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]] block_tables: dict[int, list[int]]
do_sample: bool = True do_sample: bool = True
pooling_params: Optional[PoolingParams] = None pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field( state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec # "MultiModalDataDict" types. We have to use Any due to msgspec
# not allowing a union of 2 different dicts. # not allowing a union of 2 different dicts.
token_type_ids: Optional[List[int]] = None token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Any] = None multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None token_chunk_size: Optional[int] = None
...@@ -1042,7 +1042,7 @@ class SequenceOutput( ...@@ -1042,7 +1042,7 @@ class SequenceOutput(
""" """
parent_seq_id: int parent_seq_id: int
output_token: int output_token: int
logprobs: Dict[int, Logprob] logprobs: dict[int, Logprob]
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
...@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput( ...@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
array_like=True): # type: ignore[call-arg] array_like=True): # type: ignore[call-arg]
"""The model output associated with a completion sequence group.""" """The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput __metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput] samples: list[SequenceOutput]
# Prompt logprob for each prompt query token. # Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs] prompt_logprobs: Optional[PromptLogprobs]
...@@ -1119,7 +1119,7 @@ class IntermediateTensors: ...@@ -1119,7 +1119,7 @@ class IntermediateTensors:
contains the hidden states and residuals for a request. contains the hidden states and residuals for a request.
""" """
tensors: Dict[str, torch.Tensor] tensors: dict[str, torch.Tensor]
def __init__(self, tensors): def __init__(self, tensors):
# manually define this function, so that # manually define this function, so that
...@@ -1155,7 +1155,7 @@ class PoolerOutput( ...@@ -1155,7 +1155,7 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg] array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model.""" """The output from a pooling operation in the pooling model."""
outputs: List[PoolingSequenceGroupOutput] outputs: list[PoolingSequenceGroupOutput]
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput: def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx] return self.outputs[idx]
...@@ -1172,7 +1172,7 @@ class PoolerOutput( ...@@ -1172,7 +1172,7 @@ class PoolerOutput(
def get_all_seq_ids( def get_all_seq_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
"""Given a list of SequenceGroupMetadata, create a list of all """Given a list of SequenceGroupMetadata, create a list of all
sequence ids. sequence ids.
""" """
...@@ -1180,13 +1180,13 @@ def get_all_seq_ids( ...@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
def get_all_seq_ids_and_request_ids( def get_all_seq_ids_and_request_ids(
seq_group_metadata_list: List[SequenceGroupMetadata] seq_group_metadata_list: list[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]: ) -> tuple[list[int], dict[str, set[int]]]:
"""Given a list of SequenceGroupMetadata, create a list of all """Given a list of SequenceGroupMetadata, create a list of all
sequence ids. sequence ids.
""" """
seq_ids: List[int] = [] seq_ids: list[int] = []
request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
for sg in seq_group_metadata_list: for sg in seq_group_metadata_list:
for seq_id in sg.seq_data: for seq_id in sg.seq_data:
seq_ids.append(seq_id) seq_ids.append(seq_id)
...@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it is used for last accepted tokens. # all tokens, whereas for decode step, it is used for last accepted tokens.
hidden_states: torch.Tensor hidden_states: torch.Tensor
# The sequence group metadata list. Only needed for decode step. # The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
# Scorer hidden states of the 2nd last token proposed by the proposer ( # Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when # irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the # last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored. # case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None second_last_token_hidden_states: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list) _seq_ids: list[int] = msgspec.field(default_factory=list)
def __post_init__(self): def __post_init__(self):
if self.seq_group_metadata_list is not None: if self.seq_group_metadata_list is not None:
...@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property @property
def seq_ids(self) -> List[int]: def seq_ids(self) -> list[int]:
return self._seq_ids return self._seq_ids
def update(self, def update(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: list[SequenceGroupMetadata],
second_last_token_hidden_states: Optional[torch.Tensor] = None): second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for """Update hidden states from target model invocation. Only used for
decode steps""" decode steps"""
...@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
]) ])
def prune(self, def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids. Only used for decode steps. """Prune to provided list of sequence ids. Only used for decode steps.
""" """
# Currently this prunes all seq_ids not present in # Currently this prunes all seq_ids not present in
...@@ -1287,16 +1287,16 @@ class ExecuteModelRequest( ...@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
"""The model execution request, containing CPU metadata only. The LLM """The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch.""" engine should create an instance of this class for each request batch."""
# The sequence group metadata list. # The sequence group metadata list.
seq_group_metadata_list: List[Union[SequenceGroupMetadata, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]] SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number. # Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, blocks_to_swap_in: list[tuple[int,
int]] = msgspec.field(default_factory=list) int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number. # Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, blocks_to_swap_out: list[tuple[int,
int]] = msgspec.field(default_factory=list) int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block. # Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel. # Virtual engine ID for pipeline parallel.
virtual_engine: int = 0 virtual_engine: int = 0
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
...@@ -1310,7 +1310,7 @@ class ExecuteModelRequest( ...@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
# The step index for spec model input. # The step index for spec model input.
spec_step_idx: Optional[int] = None spec_step_idx: Optional[int] = None
# Finished request ids since last step. # Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list) finished_requests_ids: list[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding. # The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback # Async callback
...@@ -1344,7 +1344,7 @@ class ExecuteModelRequest( ...@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
return state.current_step return state.current_step
def clone( def clone(
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]] SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest": ) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list.""" """Clone the request with a new sequence group metadata list."""
...@@ -1371,13 +1371,13 @@ class SequenceGroupBase: ...@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
assembled_seq_group: Optional[SequenceGroup] = None assembled_seq_group: Optional[SequenceGroup] = None
# seq id to a unique index inside this group # seq id to a unique index inside this group
seq_id_to_index: Dict[str, int] = field(default_factory=dict) seq_id_to_index: dict[str, int] = field(default_factory=dict)
# seq ids to be finished # seq ids to be finished
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict) to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)
# seq id to finished sequences # seq id to finished sequences
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict) finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)
streaming: bool = False streaming: bool = False
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from typing import Mapping, Optional from collections.abc import Mapping
from typing import Optional
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import run_once from vllm.utils import run_once
......
...@@ -28,12 +28,12 @@ import warnings ...@@ -28,12 +28,12 @@ import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Hashable, Iterable, Mapping from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, Mapping)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Dict, Generator, Generic, Iterator, List, Literal, Optional, TypeVar, Union)
NamedTuple, Optional, Tuple, Type, TypeVar, Union)
from uuid import uuid4 from uuid import uuid4
import cloudpickle import cloudpickle
...@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None], ...@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None],
async def merge_async_iterators( async def merge_async_iterators(
*iterators: AsyncGenerator[T, *iterators: AsyncGenerator[T,
None], ) -> AsyncGenerator[Tuple[int, T], None]: None], ) -> AsyncGenerator[tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator. """Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others. This method handles the case where some iterators finish before others.
...@@ -433,7 +433,7 @@ async def merge_async_iterators( ...@@ -433,7 +433,7 @@ async def merge_async_iterators(
async def collect_from_async_generator( async def collect_from_async_generator(
iterator: AsyncGenerator[T, None]) -> List[T]: iterator: AsyncGenerator[T, None]) -> list[T]:
"""Collect all items from an async generator into a list.""" """Collect all items from an async generator into a list."""
items = [] items = []
async for item in iterator: async for item in iterator:
...@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]: ...@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]:
return None return None
def update_environment_variables(envs: Dict[str, str]): def update_environment_variables(envs: dict[str, str]):
for k, v in envs.items(): for k, v in envs.items():
if k in os.environ and os.environ[k] != v: if k in os.environ and os.environ[k] != v:
logger.warning( logger.warning(
...@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]): ...@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v os.environ[k] = v
def chunk_list(lst: List[T], chunk_size: int): def chunk_list(lst: list[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst.""" """Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size): for i in range(0, len(lst), chunk_size):
yield lst[i:i + chunk_size] yield lst[i:i + chunk_size]
...@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash( ...@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash(
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
...@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash( ...@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash(
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5 scale = head_size**-0.5
key_caches: List[torch.Tensor] = [] key_caches: list[torch.Tensor] = []
value_caches: List[torch.Tensor] = [] value_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape, key_value_cache = torch.empty(size=key_value_cache_shape,
...@@ -679,7 +679,7 @@ def create_kv_caches_with_random( ...@@ -679,7 +679,7 @@ def create_kv_caches_with_random(
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16: if cache_dtype == "fp8" and head_size % 16:
raise ValueError( raise ValueError(
...@@ -693,7 +693,7 @@ def create_kv_caches_with_random( ...@@ -693,7 +693,7 @@ def create_kv_caches_with_random(
scale = head_size**-0.5 scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size() x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: List[torch.Tensor] = [] key_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape, key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
...@@ -708,7 +708,7 @@ def create_kv_caches_with_random( ...@@ -708,7 +708,7 @@ def create_kv_caches_with_random(
key_caches.append(key_cache) key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: List[torch.Tensor] = [] value_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape, value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype, dtype=torch_dtype,
...@@ -754,7 +754,7 @@ class DeviceMemoryProfiler: ...@@ -754,7 +754,7 @@ class DeviceMemoryProfiler:
def make_ndarray_with_pad( def make_ndarray_with_pad(
x: List[List[T]], x: list[list[T]],
pad: T, pad: T,
dtype: npt.DTypeLike, dtype: npt.DTypeLike,
*, *,
...@@ -779,7 +779,7 @@ def make_ndarray_with_pad( ...@@ -779,7 +779,7 @@ def make_ndarray_with_pad(
def make_tensor_with_pad( def make_tensor_with_pad(
x: List[List[T]], x: list[list[T]],
pad: T, pad: T,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
...@@ -831,7 +831,7 @@ def is_list_of( ...@@ -831,7 +831,7 @@ def is_list_of(
typ: Union[type[T], tuple[type[T], ...]], typ: Union[type[T], tuple[type[T], ...]],
*, *,
check: Literal["first", "all"] = "first", check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]: ) -> TypeIs[list[T]]:
if not isinstance(value, list): if not isinstance(value, list):
return False return False
...@@ -843,8 +843,8 @@ def is_list_of( ...@@ -843,8 +843,8 @@ def is_list_of(
assert_never(check) assert_never(check)
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T] tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable.""" """A nested JSON structure where the leaves need not be JSON-serializable."""
...@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: ...@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
return func(value) return func(value)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]: def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
"""Flatten a list of lists to a single list.""" """Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist] return [item for sublist in lists for item in sublist]
...@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return value return value
def _pull_args_from_config(self, args: List[str]) -> List[str]: def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file """Method to pull arguments specified in the config file
into the command-line args variable. into the command-line args variable.
...@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return args return args
def _load_config_file(self, file_path: str) -> List[str]: def _load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a """Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern flattened list with argparse like pattern
```yaml ```yaml
...@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
%s supplied", extension) %s supplied", extension)
# only expecting a flat dictionary of atomic types # only expecting a flat dictionary of atomic types
processed_args: List[str] = [] processed_args: list[str] = []
config: Dict[str, Union[int, str]] = {} config: dict[str, Union[int, str]] = {}
try: try:
with open(file_path) as config_file: with open(file_path) as config_file:
config = yaml.safe_load(config_file) config = yaml.safe_load(config_file)
...@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs( ...@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs(
*, *,
requires_kw_only: bool = True, requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e., """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts, given; otherwise no filtering is done), then merges the kwarg dicts,
...@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides( ...@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides(
*, *,
requires_kw_only: bool = True, requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Given a callable which has one or more keyword only params and a dict Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that can be not be kwarg mapping param names to values, drop values that can be not be kwarg
...@@ -1531,9 +1531,9 @@ class AtomicCounter: ...@@ -1531,9 +1531,9 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708 # Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]): class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: Dict[str, Callable[[], T]]): def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory self._factory = factory
self._dict: Dict[str, T] = {} self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T: def __getitem__(self, key: str) -> T:
if key not in self._dict: if key not in self._dict:
...@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]): ...@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return len(self._factory) return len(self._factory)
class ClassRegistry(UserDict[Type[T], _V]): class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: Type[T]) -> _V: def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro(): for cls in key.mro():
if cls in self.data: if cls in self.data:
return self.data[cls] return self.data[cls]
...@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: ...@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def weak_ref_tensors( def weak_ref_tensors(
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
""" """
Convenience function to create weak references to tensors, Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors. for single tensor, list of tensors or tuple of tensors.
...@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa ...@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op( def direct_register_custom_op(
op_name: str, op_name: str,
op_func: Callable, op_func: Callable,
mutates_args: List[str], mutates_args: list[str],
fake_impl: Optional[Callable] = None, fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None, target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA", dispatch_key: str = "CUDA",
...@@ -2177,8 +2177,8 @@ def get_mp_context(): ...@@ -2177,8 +2177,8 @@ def get_mp_context():
def bind_kv_cache( def bind_kv_cache(
ctx: Dict[str, Any], ctx: dict[str, Any],
kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index] kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
) -> None: ) -> None:
# Bind the kv_cache tensor to Attention modules, similar to # Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
...@@ -2210,8 +2210,8 @@ def bind_kv_cache( ...@@ -2210,8 +2210,8 @@ def bind_kv_cache(
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any], def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],
kwargs: Dict[str, Any]) -> Any: kwargs: dict[str, Any]) -> Any:
""" """
Run a method of an object with the given arguments and keyword arguments. Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr. If the method is string, it will be converted to a method using getattr.
...@@ -2263,7 +2263,7 @@ def import_pynvml(): ...@@ -2263,7 +2263,7 @@ def import_pynvml():
return pynvml return pynvml
def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]: def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
""" """
A replacement for `abc.ABC`. A replacement for `abc.ABC`.
When we use `abc.ABC`, subclasses will fail to instantiate When we use `abc.ABC`, subclasses will fail to instantiate
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod @staticmethod
...@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend):
return "FLASH_ATTN_VLLM_V1" return "FLASH_ATTN_VLLM_V1"
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl return FlashAttentionImpl
@staticmethod @staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
return FlashAttentionMetadata return FlashAttentionMetadata
@staticmethod @staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
@staticmethod @staticmethod
...@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
) -> Tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
...@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[list[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> None: ) -> None:
...@@ -381,7 +381,7 @@ def cascade_attention( ...@@ -381,7 +381,7 @@ def cascade_attention(
max_kv_len: int, max_kv_len: int,
softmax_scale: float, softmax_scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
sliding_window: Tuple[int, int], sliding_window: tuple[int, int],
logits_soft_cap: float, logits_soft_cap: float,
block_table: torch.Tensor, block_table: torch.Tensor,
common_prefix_len: int, common_prefix_len: int,
......
...@@ -195,8 +195,7 @@ return curr_o @ W_O ...@@ -195,8 +195,7 @@ return curr_o @ W_O
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
Type, TypeVar)
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
...@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend): ...@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend):
return "TRITON_MLA_VLLM_V1" return "TRITON_MLA_VLLM_V1"
@staticmethod @staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
return MLACommonMetadata return MLACommonMetadata
@staticmethod @staticmethod
def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
return MLACommonMetadataBuilder return MLACommonMetadataBuilder
@staticmethod @staticmethod
...@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend): ...@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA num_kv_heads: int, # assumed to be 1 for MLA
head_size: int, head_size: int,
) -> Tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> list[int]:
return [576] return [576]
@staticmethod @staticmethod
...@@ -317,8 +316,8 @@ class MLACommonMetadata: ...@@ -317,8 +316,8 @@ class MLACommonMetadata:
has_context: bool = False has_context: bool = False
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
context_chunk_starts: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None
context_chunk_seq_tot: Optional[List[int]] = None context_chunk_seq_tot: Optional[list[int]] = None
context_chunk_max_seq_lens: Optional[List[int]] = None context_chunk_max_seq_lens: Optional[list[int]] = None
chunked_prefill_workspace: Optional[torch.Tensor] = None chunked_prefill_workspace: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
...@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[list[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
# MLA Specific Arguments # MLA Specific Arguments
...@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# #
# returns input_group_shape, weight_group_shape # returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]: tuple[tuple[int, int], tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod): if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant: if layer.quant_method.block_quant:
weight_block_size = \ weight_block_size = \
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Optional
import torch import torch
...@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend): ...@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend):
return "FLASHMLA_VLLM_V1" return "FLASHMLA_VLLM_V1"
@staticmethod @staticmethod
def get_metadata_cls() -> Type["FlashMLAMetadata"]: def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata return FlashMLAMetadata
@staticmethod @staticmethod
def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder return FlashMLAMetadataBuilder
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashMLAImpl"]: def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl return FlashMLAImpl
@dataclass @dataclass
class FlashMLAMetadata(MLACommonMetadata): class FlashMLAMetadata(MLACommonMetadata):
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor,
torch.Tensor]] = None torch.Tensor]] = None
decode_num_splits: Optional[torch.Tensor] = None decode_num_splits: Optional[torch.Tensor] = None
...@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[list[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
# MLA Specific Arguments # MLA Specific Arguments
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment