Unverified commit cf069aa8, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Type
from typing import Any, Optional
import torch
......@@ -21,7 +21,7 @@ class TritonMLABackend(MLACommonBackend):
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
......@@ -33,10 +33,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
import torch
# Required to register custom ops.
......@@ -22,15 +22,15 @@ class PallasAttentionBackend(AttentionBackend):
return "PALLAS_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["PallasMetadata"]:
def get_metadata_cls() -> type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
......@@ -39,7 +39,7 @@ class PallasAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
......@@ -77,10 +77,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
......@@ -120,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
kv_cache: tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
import torch
......@@ -20,7 +20,7 @@ class ROCmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
......@@ -28,11 +28,11 @@ class ROCmAttentionBackend(AttentionBackend):
return "ROCM_ATTN_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["ROCmAttentionImpl"]:
def get_impl_cls() -> type["ROCmAttentionImpl"]:
return ROCmAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_metadata_cls() -> type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
......@@ -41,7 +41,7 @@ class ROCmAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
......@@ -51,7 +51,7 @@ class ROCmAttentionBackend(AttentionBackend):
return False
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
......@@ -63,10 +63,10 @@ class ROCmAttentionImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from collections.abc import Iterable
from typing import Optional
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
......@@ -29,7 +30,7 @@ class BlockPool:
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
self.blocks: List[KVCacheBlock] = [
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
......@@ -46,7 +47,7 @@ class BlockPool:
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
self.cached_block_hash_to_block: dict[BlockHashType, dict[
int, KVCacheBlock]] = defaultdict(dict)
def get_cached_block(self,
......@@ -69,8 +70,8 @@ class BlockPool:
def cache_full_blocks(
self,
request: Request,
blocks: List[KVCacheBlock],
block_hashes: List[BlockHashType],
blocks: list[KVCacheBlock],
block_hashes: list[BlockHashType],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
......@@ -146,7 +147,7 @@ class BlockPool:
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value
def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check block cache in this function.
......@@ -161,7 +162,7 @@ class BlockPool:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: List[KVCacheBlock] = []
ret: list[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
......@@ -200,7 +201,7 @@ class BlockPool:
return True
return False
def touch(self, blocks: List[KVCacheBlock]) -> None:
def touch(self, blocks: list[KVCacheBlock]) -> None:
"""Touching a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Dict, List, Set, Tuple
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -18,9 +18,9 @@ class EncoderCacheManager:
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: Dict[str, Set[int]] = {}
# List of [req_id, input_id]
self.freed: List[Tuple[str, int]] = []
self.cached: dict[str, set[int]] = {}
# List of (req_id, input_id) tuples
self.freed: list[tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool:
req_id = request.request_id
......@@ -37,7 +37,7 @@ class EncoderCacheManager:
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
def get_cached_input_ids(self, request: Request) -> Set[int]:
def get_cached_input_ids(self, request: Request) -> set[int]:
return self.cached.get(request.request_id, set())
def free_encoder_input(self, request: Request, input_id: int) -> None:
......@@ -58,7 +58,7 @@ class EncoderCacheManager:
for input_id in input_ids:
self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> List[Tuple[str, int]]:
def get_freed_ids(self) -> list[tuple[str, int]]:
freed = self.freed
self.freed = []
return freed
......@@ -67,7 +67,7 @@ class EncoderCacheManager:
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
......@@ -97,7 +97,7 @@ def compute_encoder_budget(
def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
from collections.abc import Iterable
from typing import Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
......@@ -52,20 +53,20 @@ class KVCacheManager:
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: DefaultDict[str,
List[KVCacheBlock]] = defaultdict(list)
self.req_to_blocks: defaultdict[str,
list[KVCacheBlock]] = defaultdict(list)
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: DefaultDict[
str, List[BlockHashType]] = defaultdict(list)
self.req_to_block_hashes: defaultdict[
str, list[BlockHashType]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for preempted ones.
self.num_cached_block: Dict[str, int] = {}
self.num_cached_block: dict[str, int] = {}
self.prefix_cache_stats = PrefixCacheStats()
@property
......@@ -88,7 +89,7 @@ class KVCacheManager:
return stats
def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
self, request: Request) -> tuple[list[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
......@@ -136,8 +137,8 @@ class KVCacheManager:
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[List[KVCacheBlock]] = None
) -> Optional[List[KVCacheBlock]]:
new_computed_blocks: Optional[list[KVCacheBlock]] = None
) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append.
Args:
......
......@@ -3,7 +3,7 @@
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
......@@ -25,7 +25,7 @@ class BlockHashType(NamedTuple):
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: Tuple[int, ...]
token_ids: tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
......@@ -45,7 +45,7 @@ class PrefixCachingMetrics:
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue: deque[Tuple[int, int, int]] = deque()
self.query_queue: deque[tuple[int, int, int]] = deque()
def observe(self, stats: PrefixCacheStats):
"""Observe the prefix caching for a set of requests.
......@@ -164,7 +164,7 @@ class FreeKVCacheBlockQueue:
blocks: A list of KVCacheBlock objects.
"""
def __init__(self, blocks: List[KVCacheBlock]) -> None:
def __init__(self, blocks: list[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
......@@ -233,7 +233,7 @@ class FreeKVCacheBlockQueue:
block.next_free_block = None
self.num_free_blocks += 1
def get_all_free_blocks(self) -> List[KVCacheBlock]:
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
......@@ -264,7 +264,7 @@ def need_extra_keys(request: Request) -> bool:
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
end_token_idx: int,
start_mm_idx: int) -> Tuple[List[Any], int]:
start_mm_idx: int) -> tuple[list[Any], int]:
"""Generate extra keys related to MultiModal request for block hash
computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the
......@@ -279,7 +279,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
Returns:
A tuple of extra keys and the next multi-modal index.
"""
extra_keys: List[Any] = []
extra_keys: list[Any] = []
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
......@@ -331,7 +331,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
return extra_keys, curr_mm_idx
def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
......@@ -348,7 +348,7 @@ def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
......@@ -361,12 +361,12 @@ def generate_block_hash_extra_keys(
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_extra_keys: List[Any]
mm_extra_keys: list[Any]
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx)
lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
extra_keys: List[Any] = lora_extra_keys + mm_extra_keys
extra_keys: list[Any] = lora_extra_keys + mm_extra_keys
if not extra_keys:
return None, new_start_mm_idx
......@@ -377,7 +377,7 @@ def generate_block_hash_extra_keys(
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
......@@ -410,7 +410,7 @@ def hash_block_tokens(
def hash_request_tokens(block_size: int,
request: Request) -> List[BlockHashType]:
request: Request) -> list[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
......@@ -554,8 +554,8 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: List[KVCacheSpec],
available_memory: int) -> List[KVCacheConfig]:
kv_cache_specs: list[KVCacheSpec],
available_memory: int) -> list[KVCacheConfig]:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
......
......@@ -2,7 +2,8 @@
import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
......@@ -57,24 +58,24 @@ class Scheduler:
self.block_size = self.cache_config.block_size
# req_id -> Request
self.requests: Dict[str, Request] = {}
self.requests: dict[str, Request] = {}
# Priority queues for requests.
self.waiting: Deque[Request] = deque()
self.running: List[Request] = []
self.waiting: deque[Request] = deque()
self.running: list[Request] = []
# The requests that have been scheduled and are being executed
# by the executor.
self.scheduled_req_ids: Set[str] = set()
self.scheduled_req_ids: set[str] = set()
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: Set[str] = set()
self.finished_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> CachedRequestData
self._cached_reqs_data: Dict[str, CachedRequestData] = {}
self._cached_reqs_data: dict[str, CachedRequestData] = {}
# Encoder-related.
# Calculate encoder cache size if applicable
......@@ -108,19 +109,19 @@ class Scheduler:
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
scheduled_running_reqs: List[Request] = []
preempted_reqs: List[Request] = []
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_block_ids: Dict[str, List[int]] = {}
num_scheduled_tokens: Dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: Dict[str, List[int]] = {}
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -211,7 +212,7 @@ class Scheduler:
encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
requested_loras: set[int] = set()
if self.lora_config:
requested_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
......@@ -378,7 +379,7 @@ class Scheduler:
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: List[int],
new_block_ids: list[int],
resumed_from_preemption: bool,
) -> "CachedRequestData":
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
......@@ -407,7 +408,7 @@ class Scheduler:
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
) -> Tuple[List[int], int, int]:
) -> tuple[list[int], int, int]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
......@@ -427,7 +428,7 @@ class Scheduler:
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: List[int] = []
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
assert len(mm_positions) > 0
......@@ -482,8 +483,8 @@ class Scheduler:
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
outputs: List[EngineCoreOutput] = []
new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
......@@ -543,7 +544,7 @@ class Scheduler:
stopped = False
new_logprobs = None
new_token_ids: List[int] = []
new_token_ids: list[int] = []
if request.num_computed_tokens >= request.num_tokens:
for output_token_id in generated_token_ids:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
......@@ -15,13 +15,13 @@ if TYPE_CHECKING:
class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt_token_ids: list[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"]
mm_inputs: list["MultiModalKwargs"]
mm_hashes: list[str]
mm_positions: list["PlaceholderRange"]
sampling_params: "SamplingParams"
block_ids: List[int]
block_ids: list[int]
num_computed_tokens: int
lora_request: Optional["LoRARequest"]
......@@ -29,7 +29,7 @@ class NewRequestData:
def from_request(
cls,
request: "Request",
block_ids: List[int],
block_ids: list[int],
) -> "NewRequestData":
return cls(
req_id=request.request_id,
......@@ -53,8 +53,8 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: List[int]
new_block_ids: List[int]
new_token_ids: list[int]
new_block_ids: list[int]
num_computed_tokens: int
@classmethod
......@@ -62,8 +62,8 @@ class CachedRequestData:
cls,
request: "Request",
resumed_from_preemption: bool,
new_token_ids: List[int],
new_block_ids: List[int],
new_token_ids: list[int],
new_block_ids: list[int],
) -> "CachedRequestData":
return cls(
req_id=request.request_id,
......@@ -77,29 +77,29 @@ class CachedRequestData:
@dataclass
class SchedulerOutput:
# List of the requests that are scheduled for the first time.
# List of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: List[NewRequestData]
# List of the requests that have been scheduled before.
scheduled_new_reqs: list[NewRequestData]
# List of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: List[CachedRequestData]
scheduled_cached_reqs: list[CachedRequestData]
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: Dict[str, int]
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
scheduled_spec_decode_tokens: dict[str, list[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: Dict[str, List[int]]
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests.
# This can be used for cascade attention.
num_common_prefix_blocks: int
......@@ -107,7 +107,7 @@ class SchedulerOutput:
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: Set[str]
# List of (req_id, encoder_input_index) tuples.
finished_req_ids: set[str]
# List of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: List[Tuple[str, int]]
free_encoder_input_ids: list[tuple[str, int]]
......@@ -2,7 +2,7 @@
import enum
import time
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import msgspec
......@@ -51,10 +51,10 @@ class EngineCoreRequest(
# NOTE(ywang96): original text prompt is needed when a request is added to
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[List[PlaceholderRange]]
prompt_token_ids: list[int]
mm_inputs: Optional[list[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
......@@ -93,14 +93,14 @@ class EngineCoreOutput(
gc=False): # type: ignore[call-arg]
request_id: str
new_token_ids: List[int]
new_token_ids: list[int]
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[List[EngineCoreEvent]] = None
events: Optional[list[EngineCoreEvent]] = None
@property
def finished(self) -> bool:
......@@ -129,7 +129,7 @@ class EngineCoreOutputs(
# e.g. columnwise layout
# [num_reqs]
outputs: List[EngineCoreOutput] = []
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
......
......@@ -2,7 +2,8 @@
import asyncio
import os
from typing import AsyncGenerator, List, Mapping, Optional, Set, Type, Union
from collections.abc import AsyncGenerator, Mapping
from typing import Optional, Union
import numpy as np
......@@ -39,7 +40,7 @@ class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
input_registry: InputRegistry = INPUT_REGISTRY,
......@@ -54,7 +55,7 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers: List[StatLoggerBase] = []
self.stat_loggers: list[StatLoggerBase] = []
if self.log_stats:
self.stat_loggers.extend([
LoggingStatLogger(),
......@@ -400,7 +401,7 @@ class AsyncLLM(EngineClient):
"""Remove an already loaded LoRA adapter."""
return await self.engine_core.remove_lora_async(lora_id)
async def list_loras(self) -> Set[int]:
async def list_loras(self) -> set[int]:
"""List all registered adapters."""
return await self.engine_core.list_loras_async()
......
......@@ -7,7 +7,7 @@ import time
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from typing import Any, List, Optional, Set, Tuple, Type
from typing import Any, Optional
import msgspec
import psutil
......@@ -42,7 +42,7 @@ class EngineCore:
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
):
assert vllm_config.model_config.runner_type != "pooling"
......@@ -80,7 +80,7 @@ class EngineCore:
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None
if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d",
......@@ -88,7 +88,7 @@ class EngineCore:
self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
vllm_config: VllmConfig) -> tuple[int, int]:
start = time.time()
# Get all kv cache needed by the model
......@@ -134,7 +134,7 @@ class EngineCore:
self.scheduler.add_request(req)
def abort_requests(self, request_ids: List[str]):
def abort_requests(self, request_ids: list[str]):
"""Abort requests from the scheduler."""
# TODO: The scheduler doesn't really need to know the
......@@ -228,7 +228,7 @@ class EngineCore:
def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
return self.model_executor.list_loras()
def pin_lora(self, lora_id: int) -> bool:
......@@ -244,7 +244,7 @@ class EngineCoreProc(EngineCore):
output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
):
super().__init__(vllm_config, executor_class, log_stats)
......@@ -254,7 +254,7 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
......
......@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import Any, Dict, List, Optional, Set, Type, Union
from typing import Any, Optional, Union
import zmq
import zmq.asyncio
......@@ -48,7 +48,7 @@ class EngineCoreClient(ABC):
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
) -> "EngineCoreClient":
......@@ -94,7 +94,7 @@ class EngineCoreClient(ABC):
async def execute_dummy_batch_async(self) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None:
def abort_requests(self, request_ids: list[str]) -> None:
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
......@@ -103,7 +103,7 @@ class EngineCoreClient(ABC):
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
......@@ -127,7 +127,7 @@ class EngineCoreClient(ABC):
async def wake_up_async(self) -> None:
raise NotImplementedError
async def abort_requests_async(self, request_ids: List[str]) -> None:
async def abort_requests_async(self, request_ids: list[str]) -> None:
raise NotImplementedError
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
......@@ -136,7 +136,7 @@ class EngineCoreClient(ABC):
async def remove_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def list_loras_async(self) -> Set[int]:
async def list_loras_async(self) -> set[int]:
raise NotImplementedError
async def pin_lora_async(self, lora_id: int) -> bool:
......@@ -162,7 +162,7 @@ class InprocClient(EngineCoreClient):
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)
def abort_requests(self, request_ids: List[str]) -> None:
def abort_requests(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
self.engine_core.abort_requests(request_ids)
......@@ -190,7 +190,7 @@ class InprocClient(EngineCoreClient):
def remove_lora(self, lora_id: int) -> bool:
return self.engine_core.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
return self.engine_core.list_loras()
def pin_lora(self, lora_id: int) -> bool:
......@@ -239,7 +239,7 @@ class MPClient(EngineCoreClient):
self,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
):
# The child processes will send SIGUSR1 when unrecoverable
......@@ -293,14 +293,14 @@ class MPClient(EngineCoreClient):
self.output_socket = resources.output_socket
self.input_socket = resources.input_socket
self.utility_results: Dict[int, AnyFuture] = {}
self.utility_results: dict[int, AnyFuture] = {}
def shutdown(self):
self._finalizer()
def _process_utility_output(output: UtilityOutput,
utility_results: Dict[int, AnyFuture]):
utility_results: dict[int, AnyFuture]):
"""Set the result from a utility method in the waiting future"""
future = utility_results.pop(output.call_id)
if output.failure_message is not None:
......@@ -312,7 +312,7 @@ def _process_utility_output(output: UtilityOutput,
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(
asyncio_mode=False,
......@@ -373,7 +373,7 @@ class SyncMPClient(MPClient):
request.prompt = None
self._send_input(EngineCoreRequestType.ADD, request)
def abort_requests(self, request_ids: List[str]) -> None:
def abort_requests(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
self._send_input(EngineCoreRequestType.ABORT, request_ids)
......@@ -389,7 +389,7 @@ class SyncMPClient(MPClient):
def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
return self._call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool:
......@@ -408,7 +408,7 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(
asyncio_mode=True,
......@@ -471,7 +471,7 @@ class AsyncMPClient(MPClient):
request.prompt = None
await self._send_input(EngineCoreRequestType.ADD, request)
async def abort_requests_async(self, request_ids: List[str]) -> None:
async def abort_requests_async(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
......@@ -496,7 +496,7 @@ class AsyncMPClient(MPClient):
async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> Set[int]:
async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import List, Optional
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
......@@ -17,12 +17,12 @@ class IncrementalDetokenizer:
# Generation data
output_text: str
tokens: List[str]
token_ids: List[int]
tokens: list[str]
token_ids: list[int]
prompt_len: int
# Stop strings
stop: List[str]
stop: list[str]
include_stop_str_in_output: bool
# Metadata for incremental detokenization
......@@ -41,7 +41,7 @@ class IncrementalDetokenizer:
_last_output_text_offset: int = 0
@property
def output_token_ids(self) -> List[int]:
def output_token_ids(self) -> list[int]:
return self.token_ids[self.prompt_len:]
@classmethod
......@@ -84,7 +84,7 @@ class IncrementalDetokenizer:
stop_buffer_length=stop_buffer_length,
)
def update(self, new_token_ids: List[int]) -> Optional[str]:
def update(self, new_token_ids: list[int]) -> Optional[str]:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Mapping, Optional, Set, Type, Union
from collections.abc import Mapping
from typing import Optional, Union
from typing_extensions import TypeVar
......@@ -36,10 +37,10 @@ class LLMEngine:
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
......@@ -97,7 +98,7 @@ class LLMEngine:
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
......@@ -139,7 +140,7 @@ class LLMEngine:
def validate_outputs(cls, outputs, output_type):
return outputs
def abort_request(self, request_ids: List[str]) -> None:
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
self.engine_core.abort_requests(request_ids)
......@@ -199,7 +200,7 @@ class LLMEngine:
# 3) Add the request to EngineCore.
self.engine_core.add_request(request)
def step(self) -> List[RequestOutput]:
def step(self) -> list[RequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
......@@ -241,7 +242,7 @@ class LLMEngine:
def get_tokenizer_group(
self,
group_type: Type[_G] = BaseTokenizerGroup,
group_type: type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
......@@ -263,7 +264,7 @@ class LLMEngine:
"""Remove an already loaded LoRA adapter."""
return self.engine_core.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
"""List all registered adapters."""
return self.engine_core.list_loras()
......
......@@ -2,7 +2,7 @@
import itertools
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
......@@ -151,12 +151,12 @@ class LogprobsProcessor:
@staticmethod
def _make_logprob_dict(
logprobs: List[float],
logprob_token_ids: List[int],
decoded_tokens: List[str],
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: list[str],
rank: int,
num_logprobs: int,
) -> Dict[int, Logprob]:
) -> dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
......@@ -168,7 +168,7 @@ class LogprobsProcessor:
by the user (in addition to sampled logprob)
Returns:
Dict[token id, Logprob]
dict[token id, Logprob]
"""
# We do not need a special case for the sampled token
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from vllm.config import ModelConfig
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
......@@ -68,10 +68,10 @@ class MMInputCacheClient:
def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> List[MultiModalKwargs]:
mm_hashes: Optional[list[str]],
mm_processor_kwargs: Optional[dict[str, Any]],
precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
) -> list[MultiModalKwargs]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
......@@ -88,7 +88,7 @@ class MMInputCacheClient:
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_inputs: List[MultiModalKwargs] = []
ret_inputs: list[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
......@@ -133,9 +133,9 @@ class MMInputCacheServer:
def get_and_update(
self,
mm_inputs: List[Optional[MultiModalKwargs]],
mm_hashes: List[str],
) -> List[MultiModalKwargs]:
mm_inputs: list[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
......
......@@ -2,7 +2,7 @@
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Optional, Union
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind
......@@ -18,8 +18,8 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
@dataclass
class OutputProcessorOutput:
request_outputs: List[RequestOutput]
reqs_to_abort: List[str]
request_outputs: list[RequestOutput]
reqs_to_abort: list[str]
class RequestState:
......@@ -30,7 +30,7 @@ class RequestState:
lora_name: Optional[str],
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: list[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
arrival_time: float,
......@@ -90,7 +90,7 @@ class OutputProcessor:
):
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: Dict[str, RequestState] = {}
self.request_states: dict[str, RequestState] = {}
self.lora_states = LoRARequestStates()
def is_request_active(self, request_id: str) -> bool:
......@@ -104,7 +104,7 @@ class OutputProcessor:
def abort_requests(
self,
request_ids: List[str],
request_ids: list[str],
) -> None:
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
......@@ -130,7 +130,7 @@ class OutputProcessor:
def process_outputs(
self,
engine_core_outputs: List[EngineCoreOutput],
engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: Optional[float] = None,
iteration_stats: Optional[IterationStats] = None,
) -> OutputProcessorOutput:
......@@ -158,8 +158,8 @@ class OutputProcessor:
**********************************************************
"""
request_outputs: List[RequestOutput] = []
reqs_to_abort: List[str] = []
request_outputs: list[RequestOutput] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
......@@ -265,7 +265,7 @@ class OutputProcessor:
@staticmethod
def _make_request_output(
request_state: RequestState,
new_token_ids: List[int],
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
) -> Optional[RequestOutput]:
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol,
Tuple, Union)
from typing import Optional, Protocol, Union
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
......@@ -137,7 +137,7 @@ class ParallelSamplingRequest:
key=lambda x: x.index)
return self.request_output
def get_child_info(self, index: int) -> Tuple[str, SamplingParams]:
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
"""Get child request ID and sampling params.
Args:
......@@ -237,9 +237,9 @@ class SyncParallelSamplingManager:
def __init__(self):
# Parent req ID -> parent request manager
self.parent_reqs: Dict[str, ParallelSamplingRequest] = {}
self.parent_reqs: dict[str, ParallelSamplingRequest] = {}
# Child req ID -> (child req index, parent req ID)
self.child_reqs: Dict[str, Tuple[int, str]] = {}
self.child_reqs: dict[str, tuple[int, str]] = {}
def _register_parent_request(self, req: ParallelSamplingRequest) -> None:
"""Register parallel sampling parent request."""
......@@ -299,8 +299,8 @@ class SyncParallelSamplingManager:
def step(
self,
outputs: List[RequestOutput],
) -> List[RequestOutput]:
outputs: list[RequestOutput],
) -> list[RequestOutput]:
"""Build parallel sampling request outputs.
Extract child request outputs, aggregate them
......@@ -355,7 +355,7 @@ async def generate_parallel_sampling_async(
parent_req = ParallelSamplingRequest(request_id, sampling_params)
# Aggregate generators for n child requests
gens: List[AsyncGenerator[RequestOutput, None]] = []
gens: list[AsyncGenerator[RequestOutput, None]] = []
for idx in range(parent_req.n):
child_req_id, child_params = parent_req.get_child_info(idx)
child_gen = generate(
......
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Mapping, Optional, Union
from collections.abc import Mapping
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment