Unverified Commit 1c0c6820 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix per file ruff ignores related to typing (#26254)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5f317530
...@@ -289,7 +289,7 @@ class MultiModalFieldElem: ...@@ -289,7 +289,7 @@ class MultiModalFieldElem:
return ( return (
(self.modality, self.key) == (other.modality, other.key) (self.modality, self.key) == (other.modality, other.key)
and data_equal and data_equal
and type(self.field) == type(other.field) and type(self.field) is type(other.field)
) # noqa: E721 ) # noqa: E721
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group from vllm.plugins import load_plugins_by_group
...@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__) ...@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
def get_io_processor( def get_io_processor(
vllm_config: VllmConfig, plugin_from_init: Optional[str] = None vllm_config: VllmConfig, plugin_from_init: str | None = None
) -> IOProcessor | None: ) -> IOProcessor | None:
# Input.Output processors are loaded as plugins under the # Input.Output processors are loaded as plugins under the
# 'vllm.io_processor_plugins' group. Similar to platform # 'vllm.io_processor_plugins' group. Similar to platform
......
...@@ -68,7 +68,6 @@ from typing import ( ...@@ -68,7 +68,6 @@ from typing import (
Generic, Generic,
Literal, Literal,
NamedTuple, NamedTuple,
Optional,
TextIO, TextIO,
TypeVar, TypeVar,
Union, Union,
...@@ -247,9 +246,7 @@ class CacheInfo(NamedTuple): ...@@ -247,9 +246,7 @@ class CacheInfo(NamedTuple):
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def __init__( def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None
):
super().__init__(capacity, getsizeof) super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]() self.pinned_items = set[_K]()
...@@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): ...@@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
self._LRUCache__order[key] = None # type: ignore self._LRUCache__order[key] = None # type: ignore
@overload @overload
def get(self, key: _K, /) -> Optional[_V]: ... def get(self, key: _K, /) -> _V | None: ...
@overload @overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
def get( def get(
self, key: _K, /, default: Optional[Union[_V, _T]] = None self, key: _K, /, default: Union[_V, _T] | None = None
) -> Optional[Union[_V, _T]]: ) -> Union[_V, _T] | None:
value: Optional[Union[_V, _T]] value: Union[_V, _T] | None
if key in self: if key in self:
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
...@@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): ...@@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
def pop( def pop(
self, key: _K, default: Optional[Union[_V, _T]] = None self, key: _K, default: Union[_V, _T] | None = None
) -> Optional[Union[_V, _T]]: ) -> Union[_V, _T] | None:
value: Optional[Union[_V, _T]] value: Union[_V, _T] | None
if key not in self: if key not in self:
return default return default
...@@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): ...@@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
""" """
self.pinned_items.remove(key) self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None: def _on_remove(self, key: _K, value: _V | None) -> None:
pass pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None: def remove_oldest(self, *, remove_pinned: bool = False) -> None:
...@@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: ...@@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
def make_async( def make_async(
func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]: ) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread. """Take a blocking function, and run it on in an executor thread.
...@@ -940,7 +937,7 @@ def _get_open_port() -> int: ...@@ -940,7 +937,7 @@ def _get_open_port() -> int:
return s.getsockname()[1] return s.getsockname()[1]
def find_process_using_port(port: int) -> Optional[psutil.Process]: def find_process_using_port(port: int) -> psutil.Process | None:
# TODO: We can not check for running processes with network # TODO: We can not check for running processes with network
# port on macOS. Therefore, we can not have a full graceful shutdown # port on macOS. Therefore, we can not have a full graceful shutdown
# of vLLM. For now, let's not look for processes in this case. # of vLLM. For now, let's not look for processes in this case.
...@@ -1025,8 +1022,8 @@ def _generate_random_fp8( ...@@ -1025,8 +1022,8 @@ def _generate_random_fp8(
def get_kv_cache_torch_dtype( def get_kv_cache_torch_dtype(
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
) -> torch.dtype: ) -> torch.dtype:
if isinstance(cache_dtype, str): if isinstance(cache_dtype, str):
if cache_dtype == "auto": if cache_dtype == "auto":
...@@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash( ...@@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash(
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
seed: Optional[int] = None, seed: int | None = None,
device: Optional[str] = "cuda", device: str | None = "cuda",
cache_layout: Optional[str] = "NHD", cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -1095,10 +1092,10 @@ def create_kv_caches_with_random( ...@@ -1095,10 +1092,10 @@ def create_kv_caches_with_random(
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Union[str, torch.dtype] | None = None,
seed: Optional[int] = None, seed: int | None = None,
device: Optional[str] = "cuda", device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16: if cache_dtype == "fp8" and head_size % 16:
raise ValueError( raise ValueError(
...@@ -1156,7 +1153,7 @@ def is_uva_available() -> bool: ...@@ -1156,7 +1153,7 @@ def is_uva_available() -> bool:
class DeviceMemoryProfiler: class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None): def __init__(self, device: torch.types.Device | None = None):
self.device = device self.device = device
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
...@@ -1184,7 +1181,7 @@ def make_ndarray_with_pad( ...@@ -1184,7 +1181,7 @@ def make_ndarray_with_pad(
pad: T, pad: T,
dtype: npt.DTypeLike, dtype: npt.DTypeLike,
*, *,
max_len: Optional[int] = None, max_len: int | None = None,
) -> npt.NDArray: ) -> npt.NDArray:
""" """
Make a padded array from 2D inputs. Make a padded array from 2D inputs.
...@@ -1209,8 +1206,8 @@ def make_tensor_with_pad( ...@@ -1209,8 +1206,8 @@ def make_tensor_with_pad(
pad: T, pad: T,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
max_len: Optional[int] = None, max_len: int | None = None,
device: Optional[Union[str, torch.device]] = None, device: Union[str, torch.device] | None = None,
pin_memory: bool = False, pin_memory: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -1405,7 +1402,7 @@ def find_nccl_library() -> str: ...@@ -1405,7 +1402,7 @@ def find_nccl_library() -> str:
return so_file return so_file
def find_nccl_include_paths() -> Optional[list[str]]: def find_nccl_include_paths() -> list[str] | None:
""" """
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
environment variable, or we find the library file brought by environment variable, or we find the library file brought by
...@@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any]) ...@@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def deprecate_args( def deprecate_args(
start_index: int, start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None, additional_message: str | None = None,
) -> Callable[[F], F]: ) -> Callable[[F], F]:
if not callable(is_deprecated): if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated) is_deprecated = partial(identity, is_deprecated)
...@@ -1565,7 +1562,7 @@ def deprecate_args( ...@@ -1565,7 +1562,7 @@ def deprecate_args(
def deprecate_kwargs( def deprecate_kwargs(
*kws: str, *kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None, additional_message: str | None = None,
) -> Callable[[F], F]: ) -> Callable[[F], F]:
deprecated_kws = set(kws) deprecated_kws = set(kws)
...@@ -1598,7 +1595,7 @@ def deprecate_kwargs( ...@@ -1598,7 +1595,7 @@ def deprecate_kwargs(
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int: def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for # Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes. # LRU Cache purposes.
...@@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser):
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n" " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
) )
_search_keyword: Optional[str] = None _search_keyword: str | None = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter # Set the default "formatter_class" to SortedHelpFormatter
...@@ -2245,7 +2242,7 @@ def supports_kw( ...@@ -2245,7 +2242,7 @@ def supports_kw(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Mapping[str, object]], overrides: Mapping[str, object] | None,
*, *,
requires_kw_only: bool = True, requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
...@@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa ...@@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op( def direct_register_custom_op(
op_name: str, op_name: str,
op_func: Callable, op_func: Callable,
mutates_args: Optional[list[str]] = None, mutates_args: list[str] | None = None,
fake_impl: Optional[Callable] = None, fake_impl: Callable | None = None,
target_lib: Optional[Library] = None, target_lib: Library | None = None,
dispatch_key: Optional[str] = None, dispatch_key: str | None = None,
tags: tuple[torch.Tag, ...] = (), tags: tuple[torch.Tag, ...] = (),
): ):
""" """
...@@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]: ...@@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
return scheme, host, port return scheme, host, port
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
"""Make a ZMQ path from its parts. """Make a ZMQ path from its parts.
Args: Args:
...@@ -3039,9 +3036,9 @@ def make_zmq_socket( ...@@ -3039,9 +3036,9 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str, path: str,
socket_type: Any, socket_type: Any,
bind: Optional[bool] = None, bind: bool | None = None,
identity: Optional[bytes] = None, identity: bytes | None = None,
linger: Optional[int] = None, linger: int | None = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics.""" """Make a ZMQ socket with the proper bind/connect semantics."""
...@@ -3098,9 +3095,9 @@ def make_zmq_socket( ...@@ -3098,9 +3095,9 @@ def make_zmq_socket(
def zmq_socket_ctx( def zmq_socket_ctx(
path: str, path: str,
socket_type: Any, socket_type: Any,
bind: Optional[bool] = None, bind: bool | None = None,
linger: int = 0, linger: int = 0,
identity: Optional[bytes] = None, identity: bytes | None = None,
) -> Iterator[zmq.Socket]: ) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket""" """Context manager for a ZMQ socket"""
...@@ -3163,7 +3160,7 @@ def get_mp_context(): ...@@ -3163,7 +3160,7 @@ def get_mp_context():
def bind_kv_cache( def bind_kv_cache(
ctx: dict[str, Any], ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None, shared_kv_cache_layers: dict[str, str] | None = None,
) -> None: ) -> None:
# Bind the kv_cache tensor to Attention modules, similar to # Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
...@@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: ...@@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
@contextlib.contextmanager @contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None): def cprofile_context(save_file: str | None = None):
"""Run a cprofile """Run a cprofile
Args: Args:
...@@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None): ...@@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None):
prof.print_stats(sort="cumtime") prof.print_stats(sort="cumtime")
def cprofile(save_file: Optional[str] = None, enabled: bool = True): def cprofile(save_file: str | None = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile. """Decorator to profile a Python method using cProfile.
Args: Args:
...@@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: ...@@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.write = write_with_prefix # type: ignore[method-assign] file.write = write_with_prefix # type: ignore[method-assign]
def decorate_logs(process_name: Optional[str] = None) -> None: def decorate_logs(process_name: str | None = None) -> None:
""" """
Adds a process-specific prefix to each line of output written to stdout and Adds a process-specific prefix to each line of output written to stdout and
stderr. stderr.
...@@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None: ...@@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
def length_from_prompt_token_ids_or_embeds( def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: Optional[list[int]], prompt_token_ids: list[int] | None,
prompt_embeds: Optional[torch.Tensor], prompt_embeds: torch.Tensor | None,
) -> int: ) -> int:
"""Calculate the request length (in number of tokens) give either """Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds. prompt_token_ids or prompt_embeds.
......
...@@ -10,7 +10,7 @@ from __future__ import annotations ...@@ -10,7 +10,7 @@ from __future__ import annotations
import functools import functools
import importlib import importlib
import os import os
from typing import Any, Callable, NoReturn, Optional from typing import Any, Callable, NoReturn
import torch import torch
...@@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): ...@@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear( def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype, output_dtype: torch.dtype,
weight: torch.Tensor, weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None, supports_deep_gemm: bool | None = None,
): ):
if supports_deep_gemm is None: if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported() supports_deep_gemm = is_deep_gemm_supported()
......
...@@ -12,7 +12,7 @@ import functools ...@@ -12,7 +12,7 @@ import functools
import importlib import importlib
import importlib.util import importlib.util
import os import os
from typing import Any, Callable, NoReturn, Optional from typing import Any, Callable, NoReturn
import requests import requests
import torch import torch
...@@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool: ...@@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool:
@functools.cache @functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]: def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None: if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value return env_value
def force_use_trtllm_attention() -> Optional[bool]: def force_use_trtllm_attention() -> bool | None:
""" """
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
return ``True`` if TRTLLM attention is forced to be used, return ``True`` if TRTLLM attention is forced to be used,
...@@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm( ...@@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm(
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2 assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0] assert a.shape[1] == b.shape[0]
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional, Union from typing import ClassVar, Union
import numpy as np import numpy as np
import torch import torch
...@@ -254,12 +254,12 @@ class FlashInferMetadata: ...@@ -254,12 +254,12 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning). # For cascade attention (CPU for planning).
use_cascade: bool use_cascade: bool
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: Optional[torch.Tensor] = None qo_indptr_gpu: torch.Tensor | None = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None paged_kv_indptr_gpu: torch.Tensor | None = None
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...@@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl): ...@@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl):
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: int, num_kv_heads: int,
alibi_slopes: Optional[list[float]], alibi_slopes: list[float] | None,
sliding_window: Optional[int], sliding_window: int | None,
kv_cache_dtype: str, kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None, logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None, kv_sharing_target_layer_name: int | None = None,
sinks: Optional[torch.Tensor] = None, sinks: torch.Tensor | None = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl):
"FlashInferImpl" "FlashInferImpl"
) )
self.sinks: Optional[torch.Tensor] = None self.sinks: torch.Tensor | None = None
if sinks is not None: if sinks is not None:
if sinks.shape[0] != num_heads: if sinks.shape[0] != num_heads:
raise ValueError( raise ValueError(
...@@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl): ...@@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl):
self.support_trtllm_attn = ( self.support_trtllm_attn = (
supports_trtllm_attention() and num_heads % num_kv_heads == 0 supports_trtllm_attention() and num_heads % num_kv_heads == 0
) )
self.bmm1_scale: Optional[float] = None self.bmm1_scale: float | None = None
self.bmm2_scale: Optional[float] = None self.bmm2_scale: float | None = None
self.o_sf_scale: Optional[float] = None self.o_sf_scale: float | None = None
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return ( return (
...@@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl): ...@@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None, output: torch.Tensor | None = None,
output_scale: Optional[torch.Tensor] = None, output_scale: torch.Tensor | None = None,
output_block_scale: Optional[torch.Tensor] = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashInfer. """Forward pass with FlashInfer.
...@@ -1093,13 +1093,13 @@ def fast_plan_decode( ...@@ -1093,13 +1093,13 @@ def fast_plan_decode(
page_size: int, page_size: int,
pos_encoding_mode: str = "NONE", pos_encoding_mode: str = "NONE",
window_left: int = -1, window_left: int = -1,
logits_soft_cap: Optional[float] = None, logits_soft_cap: float | None = None,
q_data_type: Optional[Union[str, torch.dtype]] = "float16", q_data_type: Union[str, torch.dtype] | None = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None, kv_data_type: Union[str, torch.dtype] | None = None,
data_type: Optional[Union[str, torch.dtype]] = None, data_type: Union[str, torch.dtype] | None = None,
sm_scale: Optional[float] = None, sm_scale: float | None = None,
rope_scale: Optional[float] = None, rope_scale: float | None = None,
rope_theta: Optional[float] = None, rope_theta: float | None = None,
non_blocking: bool = True, non_blocking: bool = True,
) -> None: ) -> None:
""" """
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm._bc_linter import bc_linter_include from vllm._bc_linter import bc_linter_include
...@@ -25,14 +25,14 @@ if TYPE_CHECKING: ...@@ -25,14 +25,14 @@ if TYPE_CHECKING:
@dataclass @dataclass
class NewRequestData: class NewRequestData:
req_id: str req_id: str
prompt_token_ids: Optional[list[int]] prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams] sampling_params: SamplingParams | None
pooling_params: Optional[PoolingParams] pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional[LoRARequest] lora_request: LoRARequest | None
prompt_embeds: Optional[torch.Tensor] = None prompt_embeds: torch.Tensor | None = None
@classmethod @classmethod
def from_request( def from_request(
...@@ -98,7 +98,7 @@ class CachedRequestData: ...@@ -98,7 +98,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty. # When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]] new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int] num_computed_tokens: list[int]
num_output_tokens: list[int] num_output_tokens: list[int]
...@@ -160,7 +160,7 @@ class SchedulerOutput: ...@@ -160,7 +160,7 @@ class SchedulerOutput:
# for filling the next token bitmask # for filling the next token bitmask
structured_output_request_ids: dict[str, int] structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch # the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]] grammar_bitmask: npt.NDArray[np.int32] | None
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None kv_connector_metadata: KVConnectorMetadata | None = None
...@@ -7,7 +7,7 @@ import itertools ...@@ -7,7 +7,7 @@ import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
...@@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface): ...@@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned # request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine # by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently. # case to track request lifetimes efficiently.
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( self.finished_req_ids_dict: dict[int, set[str]] | None = (
defaultdict(set) if include_finished_set else None defaultdict(set) if include_finished_set else None
) )
...@@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface): ...@@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface):
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[Optional[tuple[list[int], ...]]] = [] new_block_ids: list[tuple[list[int], ...] | None] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_output_tokens: list[int] = [] num_output_tokens: list[int] = []
...@@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface): ...@@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output kv_connector_output = model_runner_output.kv_connector_output
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = ( kv_connector_stats = (
kv_connector_output.kv_connector_stats if kv_connector_output else None kv_connector_output.kv_connector_stats if kv_connector_output else None
) )
...@@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface): ...@@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface):
request.status = finished_status request.status = finished_status
self._free_request(request) self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]: def _free_request(self, request: Request) -> dict[str, Any] | None:
assert request.is_finished() assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request) delay_free_blocks, kv_xfer_params = self._connector_finished(request)
...@@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface): ...@@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface):
def make_stats( def make_stats(
self, self,
spec_decoding_stats: Optional[SpecDecodingStats] = None, spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: Optional[KVConnectorStats] = None, kv_connector_stats: KVConnectorStats | None = None,
) -> Optional[SchedulerStats]: ) -> SchedulerStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
...@@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface): ...@@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface):
def make_spec_decoding_stats( def make_spec_decoding_stats(
self, self,
spec_decoding_stats: Optional[SpecDecodingStats], spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int, num_draft_tokens: int,
num_accepted_tokens: int, num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]: ) -> SpecDecodingStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
if spec_decoding_stats is None: if spec_decoding_stats is None:
...@@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface): ...@@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods # KV Connector Related Methods
######################################################################## ########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector return self.connector
def _connector_finished( def _connector_finished(
self, request: Request self, request: Request
) -> tuple[bool, Optional[dict[str, Any]]]: ) -> tuple[bool, dict[str, Any] | None]:
""" """
Invoke the KV connector request_finished() method if applicable. Invoke the KV connector request_finished() method if applicable.
......
...@@ -4,7 +4,7 @@ from __future__ import annotations ...@@ -4,7 +4,7 @@ from __future__ import annotations
import multiprocessing import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -35,11 +35,11 @@ class StructuredOutputManager: ...@@ -35,11 +35,11 @@ class StructuredOutputManager:
"""Engine-level manager for structured output requests.""" """Engine-level manager for structured output requests."""
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None self.backend: StructuredOutputBackend | None = None
self.reasoner: Optional[ReasoningParser] = None self.reasoner: ReasoningParser | None = None
self.vllm_config = vllm_config self.vllm_config = vllm_config
self._grammar_bitmask: Optional[torch.Tensor] = None self._grammar_bitmask: torch.Tensor | None = None
self._full_mask = torch.tensor(-1, dtype=torch.int32) self._full_mask = torch.tensor(-1, dtype=torch.int32)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
...@@ -168,7 +168,7 @@ class StructuredOutputManager: ...@@ -168,7 +168,7 @@ class StructuredOutputManager:
requests: dict[str, Request], requests: dict[str, Request],
structured_output_request_ids: dict[str, int], structured_output_request_ids: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]], scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]: ) -> npt.NDArray[np.int32] | None:
# Prepare the structured output bitmask for this batch. # Prepare the structured output bitmask for this batch.
if not structured_output_request_ids: if not structured_output_request_ids:
return None return None
......
...@@ -7,7 +7,7 @@ import copy ...@@ -7,7 +7,7 @@ import copy
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Union
import torch import torch
...@@ -252,7 +252,7 @@ def serialize_guidance_grammar( ...@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def validate_guidance_grammar( def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None: ) -> None:
tp, grm = get_structured_output_key(sampling_params) tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm) guidance_grm = serialize_guidance_grammar(tp, grm)
......
...@@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import ( ...@@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass @dataclasses.dataclass
class StructuredOutputRequest: class StructuredOutputRequest:
sampling_params: SamplingParams sampling_params: SamplingParams
_grammar: Optional[ _grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = (
Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] None
] = None )
reasoning_ended: Optional[bool] = None reasoning_ended: bool | None = None
def _check_grammar_completion(self) -> bool: def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports # NOTE: We have to lazy import to gate circular imports
...@@ -43,7 +43,7 @@ class StructuredOutputRequest: ...@@ -43,7 +43,7 @@ class StructuredOutputRequest:
return self._check_grammar_completion() return self._check_grammar_completion()
@property @property
def grammar(self) -> Optional[StructuredOutputGrammar]: def grammar(self) -> StructuredOutputGrammar | None:
completed = self._check_grammar_completion() completed = self._check_grammar_completion()
return ( return (
cast(Optional[StructuredOutputGrammar], self._grammar) cast(Optional[StructuredOutputGrammar], self._grammar)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -78,8 +78,8 @@ class WorkerBase: ...@@ -78,8 +78,8 @@ class WorkerBase:
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# Device and model state # Device and model state
self.device: Optional[torch.device] = None self.device: torch.device | None = None
self.model_runner: Optional[nn.Module] = None self.model_runner: nn.Module | None = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation.""" """Get specifications for KV cache implementation."""
...@@ -115,8 +115,8 @@ class WorkerBase: ...@@ -115,8 +115,8 @@ class WorkerBase:
raise NotImplementedError raise NotImplementedError
def execute_model( def execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] = None self, execute_model_req: ExecuteModelRequest | None = None
) -> Optional[list[SamplerOutput]]: ) -> list[SamplerOutput] | None:
raise NotImplementedError raise NotImplementedError
def start_worker_execution_loop(self) -> None: def start_worker_execution_loop(self) -> None:
...@@ -198,8 +198,8 @@ class WorkerWrapperBase: ...@@ -198,8 +198,8 @@ class WorkerWrapperBase:
group. group.
""" """
self.rpc_rank = rpc_rank self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None self.worker: WorkerBase | None = None
self.vllm_config: Optional[VllmConfig] = None self.vllm_config: VllmConfig | None = None
# do not store this `vllm_config`, `init_worker` will set the final # do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in # one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be # `WorkerWrapperBase`, `init_cached_hf_modules` should be
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment