Unverified commit cf069aa8, authored by Harry Mellor and committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from concurrent.futures import Future from concurrent.futures import Future
from typing import List, Type, Union from typing import Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -22,8 +22,8 @@ class Executor(ExecutorBase): ...@@ -22,8 +22,8 @@ class Executor(ExecutorBase):
For methods shared by v0 and v1, define them in ExecutorBase""" For methods shared by v0 and v1, define them in ExecutorBase"""
@staticmethod @staticmethod
def get_class(vllm_config: VllmConfig) -> Type["Executor"]: def get_class(vllm_config: VllmConfig) -> type["Executor"]:
executor_class: Type[Executor] executor_class: type[Executor]
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
distributed_executor_backend = ( distributed_executor_backend = (
parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend)
...@@ -53,7 +53,7 @@ class Executor(ExecutorBase): ...@@ -53,7 +53,7 @@ class Executor(ExecutorBase):
return executor_class return executor_class
def initialize_from_config(self, def initialize_from_config(self,
kv_cache_configs: List[KVCacheConfig]) -> None: kv_cache_configs: list[KVCacheConfig]) -> None:
""" """
Initialize the KV caches and begin the model execution loop of the Initialize the KV caches and begin the model execution loop of the
underlying workers. underlying workers.
...@@ -69,7 +69,7 @@ class Executor(ExecutorBase): ...@@ -69,7 +69,7 @@ class Executor(ExecutorBase):
# operators can be applied to all workers. # operators can be applied to all workers.
return min(output) return min(output)
def get_kv_cache_specs(self) -> List[KVCacheSpec]: def get_kv_cache_specs(self) -> list[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec") output = self.collective_rpc("get_kv_cache_spec")
return output return output
......
...@@ -10,7 +10,7 @@ from dataclasses import dataclass ...@@ -10,7 +10,7 @@ from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
import cloudpickle import cloudpickle
import psutil import psutil
...@@ -77,7 +77,7 @@ class MultiprocExecutor(Executor): ...@@ -77,7 +77,7 @@ class MultiprocExecutor(Executor):
scheduler_output_handle = self.rpc_broadcast_mq.export_handle() scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers # Create workers
self.workers: List[WorkerProcHandle] = [] self.workers: list[WorkerProcHandle] = []
for rank in range(self.world_size): for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(self.vllm_config, rank, worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank, rank,
...@@ -94,8 +94,8 @@ class MultiprocExecutor(Executor): ...@@ -94,8 +94,8 @@ class MultiprocExecutor(Executor):
def collective_rpc(self, def collective_rpc(self,
method: Union[str, Callable], method: Union[str, Callable],
timeout: Optional[float] = None, timeout: Optional[float] = None,
args: Tuple = (), args: tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]: kwargs: Optional[dict] = None) -> list[Any]:
start_time = time.monotonic() start_time = time.monotonic()
kwargs = kwargs or {} kwargs = kwargs or {}
...@@ -208,7 +208,7 @@ class WorkerProc: ...@@ -208,7 +208,7 @@ class WorkerProc:
self.rank = rank self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call # TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: List[Dict] = [ all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size) {} for _ in range(vllm_config.parallel_config.world_size)
] ]
all_kwargs[rank] = { all_kwargs[rank] = {
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List
import torch import torch
...@@ -74,7 +73,7 @@ class FullAttentionSpec(KVCacheSpecBase): ...@@ -74,7 +73,7 @@ class FullAttentionSpec(KVCacheSpecBase):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes return cdiv(num_tokens, self.block_size) * self.page_size_bytes
KVCacheSpec = Dict[str, KVCacheSpecBase] KVCacheSpec = dict[str, KVCacheSpecBase]
@dataclass @dataclass
...@@ -95,7 +94,7 @@ class KVCacheConfig: ...@@ -95,7 +94,7 @@ class KVCacheConfig:
"""The number of KV cache blocks""" """The number of KV cache blocks"""
num_blocks: int num_blocks: int
"""layer_name -> how to initialize KV cache for that layer""" """layer_name -> how to initialize KV cache for that layer"""
tensors: Dict[str, KVCacheTensor] tensors: dict[str, KVCacheTensor]
""" """
A list of kv-cache groups. Each group includes a set of layers with A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group the same kv-cache spec, and the total page_size of layers inside a group
...@@ -108,6 +107,6 @@ class KVCacheConfig: ...@@ -108,6 +107,6 @@ class KVCacheConfig:
3. (not implemented yet) A model with 2 full attention layers and 4 sliding 3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
""" """
groups: List[List[str]] groups: list[list[str]]
"""the KVCacheSpec of the model""" """the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec kv_cache_spec: KVCacheSpec
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Optional
import numpy as np import numpy as np
import prometheus_client import prometheus_client
...@@ -35,8 +35,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -35,8 +35,8 @@ class LoggingStatLogger(StatLoggerBase):
self.last_log_time = now self.last_log_time = now
# Tracked stats over current local logging interval. # Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = [] self.num_prompt_tokens: list[int] = []
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: list[int] = []
# Prefix cache metrics. TODO: Make the interval configurable. # Prefix cache metrics. TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics() self.prefix_caching_metrics = PrefixCachingMetrics()
...@@ -52,7 +52,7 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -52,7 +52,7 @@ class LoggingStatLogger(StatLoggerBase):
self.num_generation_tokens.append( self.num_generation_tokens.append(
iteration_stats.num_generation_tokens) iteration_stats.num_generation_tokens)
def _get_throughput(self, tracked_stats: List[int], now: float) -> float: def _get_throughput(self, tracked_stats: list[int], now: float) -> float:
# Compute summary metrics for tracked stats # Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time)) return float(np.sum(tracked_stats) / (now - self.last_log_time))
...@@ -147,7 +147,7 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -147,7 +147,7 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_request_success: Dict[FinishReason, self.counter_request_success: dict[FinishReason,
prometheus_client.Counter] = {} prometheus_client.Counter] = {}
counter_request_success_base = prometheus_client.Counter( counter_request_success_base = prometheus_client.Counter(
name="vllm:request_success_total", name="vllm:request_success_total",
...@@ -338,14 +338,14 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -338,14 +338,14 @@ class PrometheusStatLogger(StatLoggerBase):
prometheus_client.REGISTRY.unregister(collector) prometheus_client.REGISTRY.unregister(collector)
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
""" """
Builds a list of buckets with increasing powers of 10 multiplied by Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum. mantissa values until the value exceeds the specified maximum.
""" """
exponent = 0 exponent = 0
buckets: List[int] = [] buckets: list[int] = []
while True: while True:
for m in mantissa_lst: for m in mantissa_lst:
value = m * 10**exponent value = m * 10**exponent
...@@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: ...@@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
exponent += 1 exponent += 1
def build_1_2_5_buckets(max_value: int) -> List[int]: def build_1_2_5_buckets(max_value: int) -> list[int]:
""" """
Example: Example:
>>> build_1_2_5_buckets(100) >>> build_1_2_5_buckets(100)
...@@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: ...@@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
return build_buckets([1, 2, 5], max_value) return build_buckets([1, 2, 5], max_value)
def build_cudagraph_buckets(vllm_config: VllmConfig) -> List[int]: def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
if not vllm_config.model_config.enforce_eager: if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.\ buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy() cudagraph_capture_sizes.copy()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -39,8 +39,8 @@ class SchedulerStats: ...@@ -39,8 +39,8 @@ class SchedulerStats:
@dataclass @dataclass
class LoRAStats: class LoRAStats:
waiting_requests: Set[str] = field(default_factory=set) waiting_requests: set[str] = field(default_factory=set)
running_requests: Set[str] = field(default_factory=set) running_requests: set[str] = field(default_factory=set)
@dataclass @dataclass
...@@ -81,11 +81,11 @@ class IterationStats: ...@@ -81,11 +81,11 @@ class IterationStats:
self.num_generation_tokens = 0 self.num_generation_tokens = 0
self.num_prompt_tokens = 0 self.num_prompt_tokens = 0
self.num_preempted_reqs = 0 self.num_preempted_reqs = 0
self.finished_requests: List[FinishedRequestStats] = [] self.finished_requests: list[FinishedRequestStats] = []
self.time_to_first_tokens_iter: List[float] = [] self.time_to_first_tokens_iter: list[float] = []
self.time_per_output_tokens_iter: List[float] = [] self.time_per_output_tokens_iter: list[float] = []
self.waiting_lora_adapters: Dict[str, int] = {} self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: Dict[str, int] = {} self.running_lora_adapters: dict[str, int] = {}
def _time_since(self, start: float) -> float: def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp.""" """Calculate an interval relative to this iteration's timestamp."""
...@@ -132,7 +132,7 @@ class IterationStats: ...@@ -132,7 +132,7 @@ class IterationStats:
if num_new_generation_tokens > 0: if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: List["EngineCoreEvent"], def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats, is_prefilling: bool, req_stats: RequestStateStats,
lora_stats: Optional[LoRAStats]): lora_stats: Optional[LoRAStats]):
# Avoid circular dependency # Avoid circular dependency
...@@ -185,7 +185,7 @@ class LoRARequestStates: ...@@ -185,7 +185,7 @@ class LoRARequestStates:
"""Per-LoRA request state stats.""" """Per-LoRA request state stats."""
def __init__(self): def __init__(self):
self.lora_name_to_stats: Dict[str, LoRAStats] = {} self.lora_name_to_stats: dict[str, LoRAStats] = {}
def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
if req_state.lora_name is None: if req_state.lora_name is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional from typing import NamedTuple, Optional
import torch import torch
...@@ -9,11 +9,11 @@ import torch ...@@ -9,11 +9,11 @@ import torch
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprob_token_ids: List[List[int]] logprob_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
logprobs: List[List[float]] logprobs: list[list[float]]
# [num_reqs] # [num_reqs]
sampled_token_ranks: List[int] sampled_token_ranks: list[int]
def slice(self, start: int, end: int): def slice(self, start: int, end: int):
return LogprobsLists( return LogprobsLists(
...@@ -52,23 +52,23 @@ class SamplerOutput: ...@@ -52,23 +52,23 @@ class SamplerOutput:
# ModelRunnerOutput is serialized and sent to the scheduler process. # ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use List instead. # This is expensive for torch.Tensor so prefer to use list instead.
@dataclass @dataclass
class ModelRunnerOutput: class ModelRunnerOutput:
# [num_reqs] # [num_reqs]
req_ids: List[str] req_ids: list[str]
# req_id -> index # req_id -> index
req_id_to_index: Dict[str, int] req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens # num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens # num_generated_tokens is the number of tokens
# generated in the current step. It can be different for # generated in the current step. It can be different for
# each request due to speculative/jump decoding. # each request due to speculative/jump decoding.
sampled_token_ids: List[List[int]] sampled_token_ids: list[list[int]]
# num_reqs x num_spec_tokens # num_reqs x num_spec_tokens
spec_token_ids: Optional[List[List[int]]] spec_token_ids: Optional[list[list[int]]]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
...@@ -79,4 +79,4 @@ class ModelRunnerOutput: ...@@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs]
# [prompt_len] # [prompt_len]
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum import enum
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -20,10 +20,10 @@ class Request: ...@@ -20,10 +20,10 @@ class Request:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: list[int],
multi_modal_inputs: Optional[List["MultiModalKwargs"]], multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_hashes: Optional[List[str]], multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[List["PlaceholderRange"]], multi_modal_placeholders: Optional[list["PlaceholderRange"]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
eos_token_id: Optional[int], eos_token_id: Optional[int],
arrival_time: float, arrival_time: float,
...@@ -36,7 +36,7 @@ class Request: ...@@ -36,7 +36,7 @@ class Request:
self.lora_request = lora_request self.lora_request = lora_request
self.status = RequestStatus.WAITING self.status = RequestStatus.WAITING
self.events: List[EngineCoreEvent] = [] self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
...@@ -44,15 +44,15 @@ class Request: ...@@ -44,15 +44,15 @@ class Request:
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = [] self._output_token_ids: list[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: List[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Multi-modal related # Multi-modal related
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: List[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
...@@ -89,7 +89,7 @@ class Request: ...@@ -89,7 +89,7 @@ class Request:
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp)) timestamp))
def take_events(self) -> Optional[List[EngineCoreEvent]]: def take_events(self) -> Optional[list[EngineCoreEvent]]:
if not self.events: if not self.events:
return None return None
events, self.events = self.events, [] events, self.events = self.events, []
...@@ -97,7 +97,7 @@ class Request: ...@@ -97,7 +97,7 @@ class Request:
def append_output_token_ids( def append_output_token_ids(
self, self,
token_ids: Union[int, List[int]], token_ids: Union[int, list[int]],
) -> None: ) -> None:
if isinstance(token_ids, int): if isinstance(token_ids, int):
token_ids = [token_ids] token_ids = [token_ids]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple from typing import Optional
import torch import torch
...@@ -17,7 +17,7 @@ class SamplingMetadata: ...@@ -17,7 +17,7 @@ class SamplingMetadata:
top_k: Optional[torch.Tensor] top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor] min_p: Optional[torch.Tensor]
generators: Dict[int, torch.Generator] generators: dict[int, torch.Generator]
# None means no logprobs, 0 means sampled token logprobs only # None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int] max_num_logprobs: Optional[int]
...@@ -28,12 +28,12 @@ class SamplingMetadata: ...@@ -28,12 +28,12 @@ class SamplingMetadata:
presence_penalties: torch.Tensor presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor repetition_penalties: torch.Tensor
output_token_ids: List[List[int]] output_token_ids: list[list[int]]
# req_index -> (min_tokens, stop_token_ids) # req_index -> (min_tokens, stop_token_ids)
min_tokens: Dict[int, Tuple[int, Set[int]]] min_tokens: dict[int, tuple[int, set[int]]]
logit_bias: List[Optional[Dict[int, float]]] logit_bias: list[Optional[dict[int, float]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size). # vocab size).
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Set, Tuple
import torch import torch
from vllm.model_executor.layers.utils import apply_penalties from vllm.model_executor.layers.utils import apply_penalties
...@@ -9,13 +7,13 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad ...@@ -9,13 +7,13 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def apply_min_token_penalties( def apply_min_token_penalties(
logits: torch.Tensor, output_token_ids: List[List[int]], logits: torch.Tensor, output_token_ids: list[list[int]],
min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None: min_tokens: dict[int, tuple[int, set[int]]]) -> None:
""" """
Applies minimum token penalty by setting the logits of the stop tokens Applies minimum token penalty by setting the logits of the stop tokens
to -inf. to -inf.
""" """
min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] min_tokens_logits_to_penalize: list[tuple[int, int]] = []
for index, (min_token, stop_token_ids) in min_tokens.items(): for index, (min_token, stop_token_ids) in min_tokens.items():
if len(output_token_ids[index]) < min_token: if len(output_token_ids[index]) < min_token:
for stop_token_id in stop_token_ids: for stop_token_id in stop_token_ids:
...@@ -30,7 +28,7 @@ def apply_all_penalties( ...@@ -30,7 +28,7 @@ def apply_all_penalties(
presence_penalties: torch.Tensor, presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor, frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor, repetition_penalties: torch.Tensor,
output_token_ids: List[List[int]], output_token_ids: list[list[int]],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Applies presence, frequency and repetition penalties to the logits. Applies presence, frequency and repetition penalties to the logits.
...@@ -43,7 +41,7 @@ def apply_all_penalties( ...@@ -43,7 +41,7 @@ def apply_all_penalties(
repetition_penalties) repetition_penalties)
def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
device: torch.device) -> torch.Tensor: device: torch.device) -> torch.Tensor:
""" """
Convert the different list data structures to tensors. Convert the different list data structures to tensors.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -54,7 +54,7 @@ class TopKTopPSampler(nn.Module): ...@@ -54,7 +54,7 @@ class TopKTopPSampler(nn.Module):
def forward_native( def forward_native(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
generators: Dict[int, torch.Generator], generators: dict[int, torch.Generator],
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -66,7 +66,7 @@ class TopKTopPSampler(nn.Module): ...@@ -66,7 +66,7 @@ class TopKTopPSampler(nn.Module):
def forward_cuda( def forward_cuda(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
generators: Dict[int, torch.Generator], generators: dict[int, torch.Generator],
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -117,7 +117,7 @@ def apply_top_k_top_p( ...@@ -117,7 +117,7 @@ def apply_top_k_top_p(
def random_sample( def random_sample(
probs: torch.Tensor, probs: torch.Tensor,
generators: Dict[int, torch.Generator], generators: dict[int, torch.Generator],
) -> torch.Tensor: ) -> torch.Tensor:
"""Randomly sample from the probabilities. """Randomly sample from the probabilities.
...@@ -143,7 +143,7 @@ def flashinfer_sample( ...@@ -143,7 +143,7 @@ def flashinfer_sample(
probs: torch.Tensor, probs: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
generators: Dict[int, torch.Generator], generators: dict[int, torch.Generator],
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer. """Sample from the probabilities using FlashInfer.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module): ...@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
else: else:
self.forward_method = self.forward_native self.forward_method = self.forward_native
def forward(self, draft_token_ids: List[List[int]], def forward(self, draft_token_ids: list[list[int]],
target_probs: torch.Tensor, target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput: sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy: if not sampling_metadata.all_greedy:
...@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module): ...@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
def flashinfer_sample( def flashinfer_sample(
self, self,
draft_token_ids: List[List[int]], draft_token_ids: list[list[int]],
target_probs: torch.Tensor, target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
...@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module): ...@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance. # TODO: The following method can be optimized for better performance.
def forward_native( def forward_native(
self, self,
draft_token_ids: List[List[int]], draft_token_ids: list[list[int]],
target_probs: torch.Tensor, target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from dataclasses import field as dataclass_field from dataclasses import field as dataclass_field
from enum import IntEnum from enum import IntEnum
from typing import ClassVar, Dict, List, Optional, Set from typing import ClassVar, Optional
import msgspec import msgspec
from msgspec import field as msgspec_field from msgspec import field as msgspec_field
...@@ -78,7 +78,7 @@ class RequestStatsUpdate( ...@@ -78,7 +78,7 @@ class RequestStatsUpdate(
FINISHED (All could go to FINISHED) FINISHED (All could go to FINISHED)
""" """
_VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = { _VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = {
Type.ARRIVED: { Type.ARRIVED: {
Type.INPUT_PROCESSED, Type.INPUT_PROCESSED,
Type.FINISHED, Type.FINISHED,
...@@ -140,7 +140,7 @@ class RequestStatsUpdate( ...@@ -140,7 +140,7 @@ class RequestStatsUpdate(
finish_reason: Optional[str] = None finish_reason: Optional[str] = None
# Non-optional fields for each update type. # Non-optional fields for each update type.
_REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = { _REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = {
Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"], Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"], Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
Type.DETOKENIZED: ["num_new_tokens"], Type.DETOKENIZED: ["num_new_tokens"],
...@@ -218,13 +218,13 @@ class RequestStats: ...@@ -218,13 +218,13 @@ class RequestStats:
# 2. the request was preempted and resumed. It is equivalent to running # 2. the request was preempted and resumed. It is equivalent to running
# a prefill of the original prefill tokens + generated output tokens # a prefill of the original prefill tokens + generated output tokens
# before preemption. # before preemption.
prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list) prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# A list of timestamps when a token is decoded by the engine core. # A list of timestamps when a token is decoded by the engine core.
decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list) decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# A sorted list of timestamps for each output token. # A sorted list of timestamps for each output token.
output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list) output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# First token's timestamp. # First token's timestamp.
first_token_ts_s: Optional[float] = None first_token_ts_s: Optional[float] = None
...@@ -241,7 +241,7 @@ class RequestStats: ...@@ -241,7 +241,7 @@ class RequestStats:
# metric to measure the impact of preemption other than observation of # metric to measure the impact of preemption other than observation of
# large P99 TPOT. Ideally we could quantify the impact of preemption by # large P99 TPOT. Ideally we could quantify the impact of preemption by
# measuring the number of tokens re-computed due to preemption. # measuring the number of tokens re-computed due to preemption.
preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list) preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list)
# Timestamp when the request was finished at the engine core. # Timestamp when the request was finished at the engine core.
finished_ts_s: Optional[float] = None finished_ts_s: Optional[float] = None
...@@ -308,7 +308,7 @@ class RequestStats: ...@@ -308,7 +308,7 @@ class RequestStats:
return self.e2e_latency_s - self.first_token_latency_s return self.e2e_latency_s - self.first_token_latency_s
@property @property
def output_token_latency_s_lst(self) -> List[float]: def output_token_latency_s_lst(self) -> list[float]:
if len(self.output_token_ts_s_lst) == 0: if len(self.output_token_ts_s_lst) == 0:
return [] return []
latency_s_lst = [] latency_s_lst = []
...@@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot( ...@@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot(
default_factory=SchedulerStats) default_factory=SchedulerStats)
# Per request stats updates. # Per request stats updates.
requests_stats_updates: List[RequestStatsUpdate] = msgspec_field( requests_stats_updates: list[RequestStatsUpdate] = msgspec_field(
default_factory=list) default_factory=list)
# Engine core's queue stats. # Engine core's queue stats.
......
...@@ -5,8 +5,8 @@ import os ...@@ -5,8 +5,8 @@ import os
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List, from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Optional, TypeVar, Union, overload) Union, overload)
import torch import torch
...@@ -24,7 +24,7 @@ T = TypeVar("T") ...@@ -24,7 +24,7 @@ T = TypeVar("T")
class ConstantList(Generic[T], Sequence): class ConstantList(Generic[T], Sequence):
def __init__(self, x: List[T]) -> None: def __init__(self, x: list[T]) -> None:
self._x = x self._x = x
def append(self, item): def append(self, item):
...@@ -57,10 +57,10 @@ class ConstantList(Generic[T], Sequence): ...@@ -57,10 +57,10 @@ class ConstantList(Generic[T], Sequence):
... ...
@overload @overload
def __getitem__(self, s: slice, /) -> List[T]: def __getitem__(self, s: slice, /) -> list[T]:
... ...
def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]: def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
return self._x[item] return self._x[item]
@overload @overload
...@@ -71,7 +71,7 @@ class ConstantList(Generic[T], Sequence): ...@@ -71,7 +71,7 @@ class ConstantList(Generic[T], Sequence):
def __setitem__(self, s: slice, value: T, /): def __setitem__(self, s: slice, value: T, /):
... ...
def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]): def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
raise Exception("Cannot set item in a constant list") raise Exception("Cannot set item in a constant list")
def __delitem__(self, item): def __delitem__(self, item):
...@@ -99,7 +99,7 @@ class BackgroundProcHandle: ...@@ -99,7 +99,7 @@ class BackgroundProcHandle:
output_path: str, output_path: str,
process_name: str, process_name: str,
target_fn: Callable, target_fn: Callable,
process_kwargs: Dict[Any, Any], process_kwargs: dict[Any, Any],
): ):
context = get_mp_context() context = get_mp_context()
reader, writer = context.Pipe(duplex=False) reader, writer = context.Pipe(duplex=False)
...@@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): ...@@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
def bind_kv_cache( def bind_kv_cache(
kv_caches: Dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
forward_context: Dict[str, "Attention"], forward_context: dict[str, "Attention"],
runner_kv_caches: List[torch.Tensor], runner_kv_caches: list[torch.Tensor],
) -> None: ) -> None:
""" """
Bind the allocated KV cache to both ModelRunner and forward context so Bind the allocated KV cache to both ModelRunner and forward context so
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import numpy as np import numpy as np
import torch import torch
...@@ -40,7 +38,7 @@ class BlockTable: ...@@ -40,7 +38,7 @@ class BlockTable:
def append_row( def append_row(
self, self,
block_ids: List[int], block_ids: list[int],
row_idx: int, row_idx: int,
) -> None: ) -> None:
if not block_ids: if not block_ids:
...@@ -50,7 +48,7 @@ class BlockTable: ...@@ -50,7 +48,7 @@ class BlockTable:
self.num_blocks_per_row[row_idx] += num_blocks self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids self.block_table_np[row_idx, start:start + num_blocks] = block_ids
def add_row(self, block_ids: List[int], row_idx: int) -> None: def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Datastructures defining an input batch # Datastructures defining an input batch
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast from typing import TYPE_CHECKING, Optional, cast
import numpy as np import numpy as np
import torch import torch
...@@ -24,16 +24,16 @@ if TYPE_CHECKING: ...@@ -24,16 +24,16 @@ if TYPE_CHECKING:
class CachedRequestState: class CachedRequestState:
req_id: str req_id: str
prompt_token_ids: List[int] prompt_token_ids: list[int]
prompt: Optional[str] prompt: Optional[str]
mm_inputs: List[MultiModalKwargs] mm_inputs: list[MultiModalKwargs]
mm_positions: List["PlaceholderRange"] mm_positions: list["PlaceholderRange"]
sampling_params: SamplingParams sampling_params: SamplingParams
generator: Optional[torch.Generator] generator: Optional[torch.Generator]
block_ids: List[int] block_ids: list[int]
num_computed_tokens: int num_computed_tokens: int
output_token_ids: List[int] output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None mrope_position_delta: Optional[int] = None
...@@ -63,8 +63,8 @@ class InputBatch: ...@@ -63,8 +63,8 @@ class InputBatch:
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.vocab_size = vocab_size self.vocab_size = vocab_size
self._req_ids: List[Optional[str]] = [] self._req_ids: list[Optional[str]] = []
self.req_id_to_index: Dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big. # TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage. # Find a way to reduce the CPU memory usage.
...@@ -106,8 +106,8 @@ class InputBatch: ...@@ -106,8 +106,8 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set() self.greedy_reqs: set[str] = set()
self.random_reqs: Set[str] = set() self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ), self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32, dtype=torch.float32,
...@@ -117,7 +117,7 @@ class InputBatch: ...@@ -117,7 +117,7 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set() self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ), self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32, dtype=torch.int32,
...@@ -127,7 +127,7 @@ class InputBatch: ...@@ -127,7 +127,7 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set() self.top_k_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs, ), self.min_p = torch.empty((max_num_reqs, ),
dtype=torch.float32, dtype=torch.float32,
...@@ -137,7 +137,7 @@ class InputBatch: ...@@ -137,7 +137,7 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: Set[str] = set() self.min_p_reqs: set[str] = set()
# Frequency penalty related data structures # Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ), self.frequency_penalties = torch.empty((max_num_reqs, ),
...@@ -150,7 +150,7 @@ class InputBatch: ...@@ -150,7 +150,7 @@ class InputBatch:
pin_memory=pin_memory) pin_memory=pin_memory)
self.frequency_penalties_cpu = \ self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set() self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures # Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ), self.presence_penalties = torch.empty((max_num_reqs, ),
...@@ -162,7 +162,7 @@ class InputBatch: ...@@ -162,7 +162,7 @@ class InputBatch:
pin_memory=pin_memory) pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
) )
self.presence_penalties_reqs: Set[str] = set() self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures # Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ), self.repetition_penalties = torch.empty((max_num_reqs, ),
...@@ -175,43 +175,43 @@ class InputBatch: ...@@ -175,43 +175,43 @@ class InputBatch:
pin_memory=pin_memory) pin_memory=pin_memory)
self.repetition_penalties_cpu = \ self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# req_index -> (min_tokens, stop_token_ids) # req_index -> (min_tokens, stop_token_ids)
self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {} self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32) dtype=np.int32)
self.lora_id_to_request_ids: Dict[int, Set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator # req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own # NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary. # generator should not be included in the dictionary.
self.generators: Dict[int, torch.Generator] = {} self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {} self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs # NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase. # that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {} self.num_prompt_logprobs: dict[str, int] = {}
self.logit_bias: List[Optional[Dict[int, self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: Set[str] = set() self.has_allowed_token_ids: set[str] = set()
self.allowed_token_ids_mask: Optional[torch.Tensor] = None self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
self.req_output_token_ids: List[Optional[List[int]]] = [] self.req_output_token_ids: list[Optional[list[int]]] = []
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
@property @property
def req_ids(self) -> List[str]: def req_ids(self) -> list[str]:
# None elements should only be present transiently # None elements should only be present transiently
# while performing state updates to the batch. # while performing state updates to the batch.
return cast(List[str], self._req_ids) return cast(list[str], self._req_ids)
def add_request( def add_request(
self, self,
...@@ -417,7 +417,7 @@ class InputBatch: ...@@ -417,7 +417,7 @@ class InputBatch:
self.logit_bias[i2], self.logit_bias[i1] self.logit_bias[i2], self.logit_bias[i1]
self.block_table.swap_row(i1, i2) self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: List[int]) -> None: def condense(self, empty_req_indices: list[int]) -> None:
num_reqs = self.num_reqs num_reqs = self.num_reqs
if num_reqs == 0: if num_reqs == 0:
# The batched states are empty. # The batched states are empty.
...@@ -550,7 +550,7 @@ class InputBatch: ...@@ -550,7 +550,7 @@ class InputBatch:
frequency_penalties=self.frequency_penalties[:num_reqs], frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs], presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(List[List[int]], self.req_output_token_ids), output_token_ids=cast(list[list[int]], self.req_output_token_ids),
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs], logit_bias=self.logit_bias[:num_reqs],
...@@ -577,7 +577,7 @@ class InputBatch: ...@@ -577,7 +577,7 @@ class InputBatch:
def make_lora_inputs( def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray self, num_scheduled_tokens: np.ndarray
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
""" """
Given the num_scheduled_tokens for each request in the batch, return Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs. datastructures used to activate the current LoRAs.
...@@ -593,7 +593,7 @@ class InputBatch: ...@@ -593,7 +593,7 @@ class InputBatch:
prompt_lora_mapping = tuple(req_lora_mapping) prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple( token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens)) req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: Set[LoRARequest] = set( active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values()) self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests return prompt_lora_mapping, token_lora_mapping, active_lora_requests
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import gc import gc
import time import time
import weakref import weakref
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -135,9 +135,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -135,9 +135,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Lazy initialization # Lazy initialization
# self.model: nn.Module # Set after load_model # self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output) # req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# Set up speculative decoding. # Set up speculative decoding.
self.use_spec_decode = False self.use_spec_decode = False
...@@ -158,7 +158,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -158,7 +158,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
# Request states. # Request states.
self.requests: Dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
# Persistent batch. # Persistent batch.
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
...@@ -274,7 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -274,7 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then resubmitted with the same ID. In this case, we treat them as two # then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request # distinct requests - clearing the cached states for the first request
# and handling the second as a new request. # and handling the second as a new request.
removed_req_indices: List[int] = [] removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id) req_index = self.input_batch.remove_request(req_id)
if req_index is not None: if req_index is not None:
...@@ -305,7 +305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -305,7 +305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_index is not None assert req_index is not None
removed_req_indices.append(req_index) removed_req_indices.append(req_index)
req_ids_to_add: List[str] = [] req_ids_to_add: list[str] = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
...@@ -446,7 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -446,7 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Tuple[FlashAttentionMetadata, torch.Tensor]: ) -> tuple[FlashAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
...@@ -774,8 +774,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -774,8 +774,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs: List[MultiModalKwargs] = [] mm_inputs: list[MultiModalKwargs] = []
req_input_ids: List[Tuple[str, int]] = [] req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for input_id in encoder_input_ids: for input_id in encoder_input_ids:
...@@ -819,8 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -819,8 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _gather_encoder_outputs( def _gather_encoder_outputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
encoder_outputs: List[torch.Tensor] = [] encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id] req_id]
...@@ -1022,10 +1022,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1022,10 +1022,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def generate_draft_token_ids( def generate_draft_token_ids(
self, self,
sampled_token_ids: List[List[int]], sampled_token_ids: list[list[int]],
) -> List[List[int]]: ) -> list[list[int]]:
# TODO(woosuk): Optimize. # TODO(woosuk): Optimize.
draft_token_ids: List[List[int]] = [] draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids): for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids) num_sampled_ids = len(sampled_ids)
if not num_sampled_ids: if not num_sampled_ids:
...@@ -1069,12 +1069,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1069,12 +1069,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Dict[str, Optional[LogprobsTensors]]: ) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict: if not num_prompt_logprobs_dict:
return {} return {}
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple, # Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance. # maintainable loop over optimal performance.
...@@ -1365,7 +1365,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1365,7 +1365,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not " "Hybrid models with more than one KV cache type are not "
"supported yet.") "supported yet.")
kv_caches: Dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name] tensor_config = kv_cache_config.tensors[layer_name]
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import TYPE_CHECKING, Optional, Set from typing import TYPE_CHECKING, Optional
import torch import torch
import torch.distributed import torch.distributed
...@@ -243,7 +243,7 @@ class Worker(WorkerBase): ...@@ -243,7 +243,7 @@ class Worker(WorkerBase):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id) return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]: def list_loras(self) -> set[int]:
return self.model_runner.list_loras() return self.model_runner.list_loras()
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
......
...@@ -4,7 +4,6 @@ Define LoRA functionality mixin for model runners. ...@@ -4,7 +4,6 @@ Define LoRA functionality mixin for model runners.
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Set, Tuple
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
...@@ -57,9 +56,9 @@ class LoRAModelRunnerMixin: ...@@ -57,9 +56,9 @@ class LoRAModelRunnerMixin:
) )
return self.lora_manager.create_lora_manager(model) return self.lora_manager.create_lora_manager(model)
def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: Tuple[int, ...], token_lora_mapping: tuple[int, ...],
lora_requests: Set[LoRARequest]) -> None: lora_requests: set[LoRARequest]) -> None:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
...@@ -74,10 +73,10 @@ class LoRAModelRunnerMixin: ...@@ -74,10 +73,10 @@ class LoRAModelRunnerMixin:
def set_active_loras(self, input_batch: InputBatch, def set_active_loras(self, input_batch: InputBatch,
num_scheduled_tokens: np.ndarray) -> None: num_scheduled_tokens: np.ndarray) -> None:
prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: Tuple[int, token_lora_mapping: tuple[int,
...] # of size np.sum(num_scheduled_tokens) ...] # of size np.sum(num_scheduled_tokens)
lora_requests: Set[LoRARequest] lora_requests: set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = \ prompt_lora_mapping, token_lora_mapping, lora_requests = \
input_batch.make_lora_inputs(num_scheduled_tokens) input_batch.make_lora_inputs(num_scheduled_tokens)
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
...@@ -105,7 +104,7 @@ class LoRAModelRunnerMixin: ...@@ -105,7 +104,7 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens) num_scheduled_tokens)
# Make dummy lora requests # Make dummy lora requests
lora_requests: Set[LoRARequest] = { lora_requests: set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}", LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id, lora_int_id=lora_id,
lora_path="/not/a/real/path") lora_path="/not/a/real/path")
...@@ -143,7 +142,7 @@ class LoRAModelRunnerMixin: ...@@ -143,7 +142,7 @@ class LoRAModelRunnerMixin:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id) return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]: def list_loras(self) -> set[int]:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters() return self.lora_manager.list_adapters()
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import time import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -95,13 +95,13 @@ class TPUModelRunner: ...@@ -95,13 +95,13 @@ class TPUModelRunner:
) )
# Request states. # Request states.
self.requests: Dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
# req_id -> (input_id -> encoder_output) # req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# KV caches for forward pass # KV caches for forward pass
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensor # Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer. # The pytorch tensor and numpy array share the same buffer.
...@@ -171,7 +171,7 @@ class TPUModelRunner: ...@@ -171,7 +171,7 @@ class TPUModelRunner:
# then resubmitted with the same ID. In this case, we treat them as two # then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request # distinct requests - clearing the cached states for the first request
# and handling the second as a new request. # and handling the second as a new request.
removed_req_indices: List[int] = [] removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id) req_index = self.input_batch.remove_request(req_id)
if req_index is not None: if req_index is not None:
...@@ -194,7 +194,7 @@ class TPUModelRunner: ...@@ -194,7 +194,7 @@ class TPUModelRunner:
assert req_index is not None assert req_index is not None
removed_req_indices.append(req_index) removed_req_indices.append(req_index)
req_ids_to_add: List[str] = [] req_ids_to_add: list[str] = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
...@@ -453,7 +453,7 @@ class TPUModelRunner: ...@@ -453,7 +453,7 @@ class TPUModelRunner:
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
# Then, let's update the cache state. # Then, let's update the cache state.
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None assert req_id is not None
req_state = self.requests[req_id] req_state = self.requests[req_id]
...@@ -473,9 +473,9 @@ class TPUModelRunner: ...@@ -473,9 +473,9 @@ class TPUModelRunner:
assert all( assert all(
req_id is not None for req_id in req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None" self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
for req_id in self.input_batch.req_ids[:num_reqs]: for req_id in self.input_batch.req_ids[:num_reqs]:
prompt_logprobs_dict[req_id] = None prompt_logprobs_dict[req_id] = None
...@@ -612,7 +612,7 @@ class TPUModelRunner: ...@@ -612,7 +612,7 @@ class TPUModelRunner:
"Hybrid models with more than one KV cache type are not " "Hybrid models with more than one KV cache type are not "
"supported yet.") "supported yet.")
kv_caches: Dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name] tensor_config = kv_cache_config.tensors[layer_name]
...@@ -649,7 +649,7 @@ class ModelWrapperV1(nn.Module): ...@@ -649,7 +649,7 @@ class ModelWrapperV1(nn.Module):
self, self,
token_ids: torch.Tensor, token_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token. """Executes the forward pass of the model and samples the next token.
...@@ -667,7 +667,7 @@ class ModelWrapperV1(nn.Module): ...@@ -667,7 +667,7 @@ class ModelWrapperV1(nn.Module):
# [num_kv_heads, num_blocks, block_size, head_size]. To make it # [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify # work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly. # the slot_mapping accordingly.
# kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten() slot_mapping = slot_mapping.flatten()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A TPU worker class.""" """A TPU worker class."""
import os import os
from typing import Dict, List, Optional from typing import Optional
import torch import torch
import torch.distributed import torch.distributed
...@@ -103,7 +103,7 @@ class TPUWorker: ...@@ -103,7 +103,7 @@ class TPUWorker:
self.model_runner = TPUModelRunner(self.vllm_config, self.device) self.model_runner = TPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
kv_caches: Dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec() kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items(): for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec): if isinstance(layer_spec, FullAttentionSpec):
...@@ -118,7 +118,7 @@ class TPUWorker: ...@@ -118,7 +118,7 @@ class TPUWorker:
else: else:
raise NotImplementedError raise NotImplementedError
runner_kv_caches: List[torch.Tensor] = [] runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache( bind_kv_cache(
kv_caches, kv_caches,
self.vllm_config.compilation_config.static_forward_context, self.vllm_config.compilation_config.static_forward_context,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment