Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0
from concurrent.futures import Future
from typing import List, Type, Union
from typing import Union
import torch
import torch.distributed as dist
......@@ -22,8 +22,8 @@ class Executor(ExecutorBase):
For methods shared by v0 and v1, define them in ExecutorBase"""
@staticmethod
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
executor_class: Type[Executor]
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
executor_class: type[Executor]
parallel_config = vllm_config.parallel_config
distributed_executor_backend = (
parallel_config.distributed_executor_backend)
......@@ -53,7 +53,7 @@ class Executor(ExecutorBase):
return executor_class
def initialize_from_config(self,
kv_cache_configs: List[KVCacheConfig]) -> None:
kv_cache_configs: list[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
......@@ -69,7 +69,7 @@ class Executor(ExecutorBase):
# operators can be applied to all workers.
return min(output)
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
def get_kv_cache_specs(self) -> list[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec")
return output
......
......@@ -10,7 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.process import BaseProcess
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import cloudpickle
import psutil
......@@ -77,7 +77,7 @@ class MultiprocExecutor(Executor):
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
self.workers: List[WorkerProcHandle] = []
self.workers: list[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank,
......@@ -94,8 +94,8 @@ class MultiprocExecutor(Executor):
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
args: tuple = (),
kwargs: Optional[dict] = None) -> list[Any]:
start_time = time.monotonic()
kwargs = kwargs or {}
......@@ -208,7 +208,7 @@ class WorkerProc:
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: List[Dict] = [
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
all_kwargs[rank] = {
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List
import torch
......@@ -74,7 +73,7 @@ class FullAttentionSpec(KVCacheSpecBase):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
KVCacheSpec = Dict[str, KVCacheSpecBase]
KVCacheSpec = dict[str, KVCacheSpecBase]
@dataclass
......@@ -95,7 +94,7 @@ class KVCacheConfig:
"""The number of KV cache blocks"""
num_blocks: int
"""layer_name -> how to initialize KV cache for that layer"""
tensors: Dict[str, KVCacheTensor]
tensors: dict[str, KVCacheTensor]
"""
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
......@@ -108,6 +107,6 @@ class KVCacheConfig:
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
groups: List[List[str]]
groups: list[list[str]]
"""the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec
......@@ -2,7 +2,7 @@
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from typing import Optional
import numpy as np
import prometheus_client
......@@ -35,8 +35,8 @@ class LoggingStatLogger(StatLoggerBase):
self.last_log_time = now
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.num_prompt_tokens: list[int] = []
self.num_generation_tokens: list[int] = []
# Prefix cache metrics. TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
......@@ -52,7 +52,7 @@ class LoggingStatLogger(StatLoggerBase):
self.num_generation_tokens.append(
iteration_stats.num_generation_tokens)
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
# Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time))
......@@ -147,7 +147,7 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Number of generation tokens processed.",
labelnames=labelnames).labels(*labelvalues)
self.counter_request_success: Dict[FinishReason,
self.counter_request_success: dict[FinishReason,
prometheus_client.Counter] = {}
counter_request_success_base = prometheus_client.Counter(
name="vllm:request_success_total",
......@@ -338,14 +338,14 @@ class PrometheusStatLogger(StatLoggerBase):
prometheus_client.REGISTRY.unregister(collector)
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
"""
exponent = 0
buckets: List[int] = []
buckets: list[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
......@@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
exponent += 1
def build_1_2_5_buckets(max_value: int) -> List[int]:
def build_1_2_5_buckets(max_value: int) -> list[int]:
"""
Example:
>>> build_1_2_5_buckets(100)
......@@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
return build_buckets([1, 2, 5], max_value)
def build_cudagraph_buckets(vllm_config: VllmConfig) -> List[int]:
def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
......
......@@ -2,7 +2,7 @@
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.outputs import RequestOutput
......@@ -39,8 +39,8 @@ class SchedulerStats:
@dataclass
class LoRAStats:
waiting_requests: Set[str] = field(default_factory=set)
running_requests: Set[str] = field(default_factory=set)
waiting_requests: set[str] = field(default_factory=set)
running_requests: set[str] = field(default_factory=set)
@dataclass
......@@ -81,11 +81,11 @@ class IterationStats:
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.num_preempted_reqs = 0
self.finished_requests: List[FinishedRequestStats] = []
self.time_to_first_tokens_iter: List[float] = []
self.time_per_output_tokens_iter: List[float] = []
self.waiting_lora_adapters: Dict[str, int] = {}
self.running_lora_adapters: Dict[str, int] = {}
self.finished_requests: list[FinishedRequestStats] = []
self.time_to_first_tokens_iter: list[float] = []
self.time_per_output_tokens_iter: list[float] = []
self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: dict[str, int] = {}
def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp."""
......@@ -132,7 +132,7 @@ class IterationStats:
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: List["EngineCoreEvent"],
def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,
lora_stats: Optional[LoRAStats]):
# Avoid circular dependency
......@@ -185,7 +185,7 @@ class LoRARequestStates:
"""Per-LoRA request state stats."""
def __init__(self):
self.lora_name_to_stats: Dict[str, LoRAStats] = {}
self.lora_name_to_stats: dict[str, LoRAStats] = {}
def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
if req_state.lora_name is None:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional
from typing import NamedTuple, Optional
import torch
......@@ -9,11 +9,11 @@ import torch
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: List[List[int]]
logprob_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
logprobs: List[List[float]]
logprobs: list[list[float]]
# [num_reqs]
sampled_token_ranks: List[int]
sampled_token_ranks: list[int]
def slice(self, start: int, end: int):
return LogprobsLists(
......@@ -52,23 +52,23 @@ class SamplerOutput:
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use List instead.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
class ModelRunnerOutput:
# [num_reqs]
req_ids: List[str]
req_ids: list[str]
# req_id -> index
req_id_to_index: Dict[str, int]
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: List[List[int]]
sampled_token_ids: list[list[int]]
# num_reqs x num_spec_tokens
spec_token_ids: Optional[List[List[int]]]
spec_token_ids: Optional[list[list[int]]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
......@@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# SPDX-License-Identifier: Apache-2.0
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
......@@ -20,10 +20,10 @@ class Request:
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: List[int],
multi_modal_inputs: Optional[List["MultiModalKwargs"]],
multi_modal_hashes: Optional[List[str]],
multi_modal_placeholders: Optional[List["PlaceholderRange"]],
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
......@@ -36,7 +36,7 @@ class Request:
self.lora_request = lora_request
self.status = RequestStatus.WAITING
self.events: List[EngineCoreEvent] = []
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
......@@ -44,15 +44,15 @@ class Request:
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.spec_token_ids: List[int] = []
self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
# Multi-modal related
self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: List[str] = multi_modal_hashes or []
self.mm_hashes: list[str] = multi_modal_hashes or []
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
......@@ -89,7 +89,7 @@ class Request:
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp))
def take_events(self) -> Optional[List[EngineCoreEvent]]:
def take_events(self) -> Optional[list[EngineCoreEvent]]:
if not self.events:
return None
events, self.events = self.events, []
......@@ -97,7 +97,7 @@ class Request:
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
token_ids: Union[int, list[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import torch
......@@ -17,7 +17,7 @@ class SamplingMetadata:
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]
generators: Dict[int, torch.Generator]
generators: dict[int, torch.Generator]
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
......@@ -28,12 +28,12 @@ class SamplingMetadata:
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
output_token_ids: list[list[int]]
# req_index -> (min_tokens, stop_token_ids)
min_tokens: Dict[int, Tuple[int, Set[int]]]
min_tokens: dict[int, tuple[int, set[int]]]
logit_bias: List[Optional[Dict[int, float]]]
logit_bias: list[Optional[dict[int, float]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.utils import apply_penalties
......@@ -9,13 +7,13 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def apply_min_token_penalties(
logits: torch.Tensor, output_token_ids: List[List[int]],
min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
logits: torch.Tensor, output_token_ids: list[list[int]],
min_tokens: dict[int, tuple[int, set[int]]]) -> None:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
min_tokens_logits_to_penalize: list[tuple[int, int]] = []
for index, (min_token, stop_token_ids) in min_tokens.items():
if len(output_token_ids[index]) < min_token:
for stop_token_id in stop_token_ids:
......@@ -30,7 +28,7 @@ def apply_all_penalties(
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: List[List[int]],
output_token_ids: list[list[int]],
) -> torch.Tensor:
"""
Applies presence, frequency and repetition penalties to the logits.
......@@ -43,7 +41,7 @@ def apply_all_penalties(
repetition_penalties)
def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional
from typing import Optional
import torch
import torch.nn as nn
......@@ -54,7 +54,7 @@ class TopKTopPSampler(nn.Module):
def forward_native(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
......@@ -66,7 +66,7 @@ class TopKTopPSampler(nn.Module):
def forward_cuda(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
......@@ -117,7 +117,7 @@ def apply_top_k_top_p(
def random_sample(
probs: torch.Tensor,
generators: Dict[int, torch.Generator],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
......@@ -143,7 +143,7 @@ def flashinfer_sample(
probs: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
generators: Dict[int, torch.Generator],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer.
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import torch
import torch.nn as nn
......@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
else:
self.forward_method = self.forward_native
def forward(self, draft_token_ids: List[List[int]],
def forward(self, draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
......@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
def flashinfer_sample(
self,
draft_token_ids: List[List[int]],
draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
......@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
draft_token_ids: List[List[int]],
draft_token_ids: list[list[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
......
......@@ -4,7 +4,7 @@ import time
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import IntEnum
from typing import ClassVar, Dict, List, Optional, Set
from typing import ClassVar, Optional
import msgspec
from msgspec import field as msgspec_field
......@@ -78,7 +78,7 @@ class RequestStatsUpdate(
FINISHED (All could go to FINISHED)
"""
_VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = {
_VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = {
Type.ARRIVED: {
Type.INPUT_PROCESSED,
Type.FINISHED,
......@@ -140,7 +140,7 @@ class RequestStatsUpdate(
finish_reason: Optional[str] = None
# Non-optional fields for each update type.
_REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = {
_REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = {
Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
Type.DETOKENIZED: ["num_new_tokens"],
......@@ -218,13 +218,13 @@ class RequestStats:
# 2. the request was preempted and resumed. It is equivalent to running
# a prefill of the original prefill tokens + generated output tokens
# before preemption.
prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list)
prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# A list of timestamps when a token is decoded by the engine core.
decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list)
decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# A sorted list of timestamps for each output token.
output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list)
output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# First token's timestamp.
first_token_ts_s: Optional[float] = None
......@@ -241,7 +241,7 @@ class RequestStats:
# metric to measure the impact of preemption other than observation of
# large P99 TPOT. Ideally we could quantify the impact of preemption by
# measuring the number of tokens re-computed due to preemption.
preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list)
preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# Timestamp when the request was finished at the engine core.
finished_ts_s: Optional[float] = None
......@@ -308,7 +308,7 @@ class RequestStats:
return self.e2e_latency_s - self.first_token_latency_s
@property
def output_token_latency_s_lst(self) -> List[float]:
def output_token_latency_s_lst(self) -> list[float]:
if len(self.output_token_ts_s_lst) == 0:
return []
latency_s_lst = []
......@@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot(
default_factory=SchedulerStats)
# Per request stats updates.
requests_stats_updates: List[RequestStatsUpdate] = msgspec_field(
requests_stats_updates: list[RequestStatsUpdate] = msgspec_field(
default_factory=list)
# Engine core's queue stats.
......
......@@ -5,8 +5,8 @@ import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
Optional, TypeVar, Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import torch
......@@ -24,7 +24,7 @@ T = TypeVar("T")
class ConstantList(Generic[T], Sequence):
def __init__(self, x: List[T]) -> None:
def __init__(self, x: list[T]) -> None:
self._x = x
def append(self, item):
......@@ -57,10 +57,10 @@ class ConstantList(Generic[T], Sequence):
...
@overload
def __getitem__(self, s: slice, /) -> List[T]:
def __getitem__(self, s: slice, /) -> list[T]:
...
def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
return self._x[item]
@overload
......@@ -71,7 +71,7 @@ class ConstantList(Generic[T], Sequence):
def __setitem__(self, s: slice, value: T, /):
...
def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
raise Exception("Cannot set item in a constant list")
def __delitem__(self, item):
......@@ -99,7 +99,7 @@ class BackgroundProcHandle:
output_path: str,
process_name: str,
target_fn: Callable,
process_kwargs: Dict[Any, Any],
process_kwargs: dict[Any, Any],
):
context = get_mp_context()
reader, writer = context.Pipe(duplex=False)
......@@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
def bind_kv_cache(
kv_caches: Dict[str, torch.Tensor],
forward_context: Dict[str, "Attention"],
runner_kv_caches: List[torch.Tensor],
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import numpy as np
import torch
......@@ -40,7 +38,7 @@ class BlockTable:
def append_row(
self,
block_ids: List[int],
block_ids: list[int],
row_idx: int,
) -> None:
if not block_ids:
......@@ -50,7 +48,7 @@ class BlockTable:
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
def add_row(self, block_ids: List[int], row_idx: int) -> None:
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
......
......@@ -2,7 +2,7 @@
# Datastructures defining an input batch
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
from typing import TYPE_CHECKING, Optional, cast
import numpy as np
import torch
......@@ -24,16 +24,16 @@ if TYPE_CHECKING:
class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt_token_ids: list[int]
prompt: Optional[str]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
mm_inputs: list[MultiModalKwargs]
mm_positions: list["PlaceholderRange"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: List[int]
block_ids: list[int]
num_computed_tokens: int
output_token_ids: List[int]
output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
......@@ -63,8 +63,8 @@ class InputBatch:
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self._req_ids: List[Optional[str]] = []
self.req_id_to_index: Dict[str, int] = {}
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
......@@ -106,8 +106,8 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set()
self.random_reqs: Set[str] = set()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
......@@ -117,7 +117,7 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
......@@ -127,7 +127,7 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
self.top_k_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
......@@ -137,7 +137,7 @@ class InputBatch:
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: Set[str] = set()
self.min_p_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
......@@ -150,7 +150,7 @@ class InputBatch:
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ),
......@@ -162,7 +162,7 @@ class InputBatch:
pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
)
self.presence_penalties_reqs: Set[str] = set()
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ),
......@@ -175,43 +175,43 @@ class InputBatch:
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set()
self.repetition_penalties_reqs: set[str] = set()
# req_index -> (min_tokens, stop_token_ids)
self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: Dict[int, torch.Generator] = {}
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}
self.num_prompt_logprobs: dict[str, int] = {}
self.logit_bias: List[Optional[Dict[int,
self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: Set[str] = set()
self.has_allowed_token_ids: set[str] = set()
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
self.req_output_token_ids: List[Optional[List[int]]] = []
self.req_output_token_ids: list[Optional[list[int]]] = []
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@property
def req_ids(self) -> List[str]:
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(List[str], self._req_ids)
return cast(list[str], self._req_ids)
def add_request(
self,
......@@ -417,7 +417,7 @@ class InputBatch:
self.logit_bias[i2], self.logit_bias[i1]
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: List[int]) -> None:
def condense(self, empty_req_indices: list[int]) -> None:
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
......@@ -550,7 +550,7 @@ class InputBatch:
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(List[List[int]], self.req_output_token_ids),
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
......@@ -577,7 +577,7 @@ class InputBatch:
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
......@@ -593,7 +593,7 @@ class InputBatch:
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: Set[LoRARequest] = set(
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
......
......@@ -3,7 +3,7 @@
import gc
import time
import weakref
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch
......@@ -135,9 +135,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
......@@ -158,7 +158,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
......@@ -274,7 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: List[int] = []
removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
......@@ -305,7 +305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_index is not None
removed_req_indices.append(req_index)
req_ids_to_add: List[str] = []
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
......@@ -446,7 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
) -> tuple[FlashAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
......@@ -774,8 +774,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs: List[MultiModalKwargs] = []
req_input_ids: List[Tuple[str, int]] = []
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
......@@ -819,8 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> List[torch.Tensor]:
encoder_outputs: List[torch.Tensor] = []
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
......@@ -1022,10 +1022,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def generate_draft_token_ids(
self,
sampled_token_ids: List[List[int]],
) -> List[List[int]]:
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: List[List[int]] = []
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
......@@ -1069,12 +1069,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, Optional[LogprobsTensors]]:
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
......@@ -1365,7 +1365,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
......
......@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
......@@ -243,7 +243,7 @@ class Worker(WorkerBase):
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
return self.model_runner.list_loras()
def pin_lora(self, lora_id: int) -> bool:
......
......@@ -4,7 +4,6 @@ Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import Set, Tuple
import numpy as np
import torch.nn as nn
......@@ -57,9 +56,9 @@ class LoRAModelRunnerMixin:
)
return self.lora_manager.create_lora_manager(model)
def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
token_lora_mapping: Tuple[int, ...],
lora_requests: Set[LoRARequest]) -> None:
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest]) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
......@@ -74,10 +73,10 @@ class LoRAModelRunnerMixin:
def set_active_loras(self, input_batch: InputBatch,
num_scheduled_tokens: np.ndarray) -> None:
prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: Tuple[int,
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: tuple[int,
...] # of size np.sum(num_scheduled_tokens)
lora_requests: Set[LoRARequest]
lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = \
input_batch.make_lora_inputs(num_scheduled_tokens)
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
......@@ -105,7 +104,7 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens)
# Make dummy lora requests
lora_requests: Set[LoRARequest] = {
lora_requests: set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path")
......@@ -143,7 +142,7 @@ class LoRAModelRunnerMixin:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch
import numpy as np
......@@ -95,13 +95,13 @@ class TPUModelRunner:
)
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
self.requests: dict[str, CachedRequestState] = {}
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# KV caches for forward pass
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
......@@ -171,7 +171,7 @@ class TPUModelRunner:
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: List[int] = []
removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
......@@ -194,7 +194,7 @@ class TPUModelRunner:
assert req_index is not None
removed_req_indices.append(req_index)
req_ids_to_add: List[str] = []
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
......@@ -453,7 +453,7 @@ class TPUModelRunner:
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
# Then, let's update the cache state.
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
......@@ -473,9 +473,9 @@ class TPUModelRunner:
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
for req_id in self.input_batch.req_ids[:num_reqs]:
prompt_logprobs_dict[req_id] = None
......@@ -612,7 +612,7 @@ class TPUModelRunner:
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
......@@ -649,7 +649,7 @@ class ModelWrapperV1(nn.Module):
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
......@@ -667,7 +667,7 @@ class ModelWrapperV1(nn.Module):
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]
# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
......
# SPDX-License-Identifier: Apache-2.0
"""A TPU worker class."""
import os
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.distributed
......@@ -103,7 +103,7 @@ class TPUWorker:
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int:
kv_caches: Dict[str, torch.Tensor] = {}
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
......@@ -118,7 +118,7 @@ class TPUWorker:
else:
raise NotImplementedError
runner_kv_caches: List[torch.Tensor] = []
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment