"vllm/vscode:/vscode.git/clone" did not exist on "e4bf6ed90d47a16f85d44ac225210884aa291f8c"
Unverified Commit 1c0c6820 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix per file ruff ignores related to typing (#26254)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5f317530
......@@ -289,7 +289,7 @@ class MultiModalFieldElem:
return (
(self.modality, self.key) == (other.modality, other.key)
and data_equal
and type(self.field) == type(other.field)
and type(self.field) is type(other.field)
) # noqa: E721
......
......@@ -4,7 +4,6 @@
from __future__ import annotations
import logging
from typing import Optional
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
......@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
def get_io_processor(
vllm_config: VllmConfig, plugin_from_init: Optional[str] = None
vllm_config: VllmConfig, plugin_from_init: str | None = None
) -> IOProcessor | None:
# Input.Output processors are loaded as plugins under the
# 'vllm.io_processor_plugins' group. Similar to platform
......
......@@ -68,7 +68,6 @@ from typing import (
Generic,
Literal,
NamedTuple,
Optional,
TextIO,
TypeVar,
Union,
......@@ -247,9 +246,7 @@ class CacheInfo(NamedTuple):
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def __init__(
self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None
):
def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]()
......@@ -324,15 +321,15 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
self._LRUCache__order[key] = None # type: ignore
@overload
def get(self, key: _K, /) -> Optional[_V]: ...
def get(self, key: _K, /) -> _V | None: ...
@overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
def get(
self, key: _K, /, default: Optional[Union[_V, _T]] = None
) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
self, key: _K, /, default: Union[_V, _T] | None = None
) -> Union[_V, _T] | None:
value: Union[_V, _T] | None
if key in self:
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
......@@ -350,9 +347,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
def pop(
self, key: _K, default: Optional[Union[_V, _T]] = None
) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
self, key: _K, default: Union[_V, _T] | None = None
) -> Union[_V, _T] | None:
value: Union[_V, _T] | None
if key not in self:
return default
......@@ -379,7 +376,7 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""
self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
def _on_remove(self, key: _K, value: _V | None) -> None:
pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
......@@ -705,7 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
def make_async(
func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None
func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
......@@ -940,7 +937,7 @@ def _get_open_port() -> int:
return s.getsockname()[1]
def find_process_using_port(port: int) -> Optional[psutil.Process]:
def find_process_using_port(port: int) -> psutil.Process | None:
# TODO: We can not check for running processes with network
# port on macOS. Therefore, we can not have a full graceful shutdown
# of vLLM. For now, let's not look for processes in this case.
......@@ -1025,8 +1022,8 @@ def _generate_random_fp8(
def get_kv_cache_torch_dtype(
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Union[str, torch.dtype] | None = None,
) -> torch.dtype:
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
......@@ -1053,11 +1050,11 @@ def create_kv_caches_with_random_flash(
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = None,
device: Optional[str] = "cuda",
cache_layout: Optional[str] = "NHD",
cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Union[str, torch.dtype] | None = None,
seed: int | None = None,
device: str | None = "cuda",
cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform
......@@ -1095,10 +1092,10 @@ def create_kv_caches_with_random(
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = None,
device: Optional[str] = "cuda",
cache_dtype: Union[str, torch.dtype] | None,
model_dtype: Union[str, torch.dtype] | None = None,
seed: int | None = None,
device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
......@@ -1156,7 +1153,7 @@ def is_uva_available() -> bool:
class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None):
def __init__(self, device: torch.types.Device | None = None):
self.device = device
def current_memory_usage(self) -> float:
......@@ -1184,7 +1181,7 @@ def make_ndarray_with_pad(
pad: T,
dtype: npt.DTypeLike,
*,
max_len: Optional[int] = None,
max_len: int | None = None,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
......@@ -1209,8 +1206,8 @@ def make_tensor_with_pad(
pad: T,
dtype: torch.dtype,
*,
max_len: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
max_len: int | None = None,
device: Union[str, torch.device] | None = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
......@@ -1405,7 +1402,7 @@ def find_nccl_library() -> str:
return so_file
def find_nccl_include_paths() -> Optional[list[str]]:
def find_nccl_include_paths() -> list[str] | None:
"""
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
environment variable, or we find the library file brought by
......@@ -1525,7 +1522,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def deprecate_args(
start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
additional_message: str | None = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
......@@ -1565,7 +1562,7 @@ def deprecate_args(
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
additional_message: str | None = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws)
......@@ -1598,7 +1595,7 @@ def deprecate_kwargs(
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
......@@ -1746,7 +1743,7 @@ class FlexibleArgumentParser(ArgumentParser):
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
_search_keyword: Optional[str] = None
_search_keyword: str | None = None
def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
......@@ -2245,7 +2242,7 @@ def supports_kw(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],
overrides: Mapping[str, object] | None,
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
......@@ -2695,10 +2692,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: Optional[list[str]] = None,
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: Optional[str] = None,
mutates_args: list[str] | None = None,
fake_impl: Callable | None = None,
target_lib: Library | None = None,
dispatch_key: str | None = None,
tags: tuple[torch.Tag, ...] = (),
):
"""
......@@ -3016,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
return scheme, host, port
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str:
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
"""Make a ZMQ path from its parts.
Args:
......@@ -3039,9 +3036,9 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
linger: Optional[int] = None,
bind: bool | None = None,
identity: bytes | None = None,
linger: int | None = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
......@@ -3098,9 +3095,9 @@ def make_zmq_socket(
def zmq_socket_ctx(
path: str,
socket_type: Any,
bind: Optional[bool] = None,
bind: bool | None = None,
linger: int = 0,
identity: Optional[bytes] = None,
identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
......@@ -3163,7 +3160,7 @@ def get_mp_context():
def bind_kv_cache(
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None,
shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
......@@ -3379,7 +3376,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
@contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None):
def cprofile_context(save_file: str | None = None):
"""Run a cprofile
Args:
......@@ -3401,7 +3398,7 @@ def cprofile_context(save_file: Optional[str] = None):
prof.print_stats(sort="cumtime")
def cprofile(save_file: Optional[str] = None, enabled: bool = True):
def cprofile(save_file: str | None = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile.
Args:
......@@ -3608,7 +3605,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.write = write_with_prefix # type: ignore[method-assign]
def decorate_logs(process_name: Optional[str] = None) -> None:
def decorate_logs(process_name: str | None = None) -> None:
"""
Adds a process-specific prefix to each line of output written to stdout and
stderr.
......@@ -3631,8 +3628,8 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
prompt_token_ids: list[int] | None,
prompt_embeds: torch.Tensor | None,
) -> int:
"""Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds.
......
......@@ -10,7 +10,7 @@ from __future__ import annotations
import functools
import importlib
import os
from typing import Any, Callable, NoReturn, Optional
from typing import Any, Callable, NoReturn
import torch
......@@ -325,7 +325,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
......
......@@ -12,7 +12,7 @@ import functools
import importlib
import importlib.util
import os
from typing import Any, Callable, NoReturn, Optional
from typing import Any, Callable, NoReturn
import requests
import torch
......@@ -202,14 +202,14 @@ def supports_trtllm_attention() -> bool:
@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> Optional[bool]:
def force_use_trtllm_attention() -> bool | None:
"""
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
return ``True`` if TRTLLM attention is forced to be used,
......@@ -401,7 +401,7 @@ def flashinfer_scaled_fp8_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0]
......
......@@ -5,7 +5,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
from typing import ClassVar, Union
import numpy as np
import torch
......@@ -254,12 +254,12 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade: bool
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: Optional[torch.Tensor] = None
paged_kv_indptr_gpu: Optional[torch.Tensor] = None
qo_indptr_gpu: torch.Tensor | None = None
paged_kv_indptr_gpu: torch.Tensor | None = None
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
......@@ -727,13 +727,13 @@ class FlashInferImpl(AttentionImpl):
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: float | None = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -763,7 +763,7 @@ class FlashInferImpl(AttentionImpl):
"FlashInferImpl"
)
self.sinks: Optional[torch.Tensor] = None
self.sinks: torch.Tensor | None = None
if sinks is not None:
if sinks.shape[0] != num_heads:
raise ValueError(
......@@ -776,9 +776,9 @@ class FlashInferImpl(AttentionImpl):
self.support_trtllm_attn = (
supports_trtllm_attention() and num_heads % num_kv_heads == 0
)
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None
self.o_sf_scale: Optional[float] = None
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
def fused_output_quant_supported(self, quant_key: QuantKey):
return (
......@@ -795,9 +795,9 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
......@@ -1093,13 +1093,13 @@ def fast_plan_decode(
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
q_data_type: Optional[Union[str, torch.dtype]] = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
logits_soft_cap: float | None = None,
q_data_type: Union[str, torch.dtype] | None = "float16",
kv_data_type: Union[str, torch.dtype] | None = None,
data_type: Union[str, torch.dtype] | None = None,
sm_scale: float | None = None,
rope_scale: float | None = None,
rope_theta: float | None = None,
non_blocking: bool = True,
) -> None:
"""
......
......@@ -4,7 +4,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from vllm._bc_linter import bc_linter_include
......@@ -25,14 +25,14 @@ if TYPE_CHECKING:
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: Optional[list[int]]
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
prompt_embeds: Optional[torch.Tensor] = None
lora_request: LoRARequest | None
prompt_embeds: torch.Tensor | None = None
@classmethod
def from_request(
......@@ -98,7 +98,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]]
new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int]
num_output_tokens: list[int]
......@@ -160,7 +160,7 @@ class SchedulerOutput:
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
grammar_bitmask: npt.NDArray[np.int32] | None
# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
kv_connector_metadata: KVConnectorMetadata | None = None
......@@ -7,7 +7,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
......@@ -64,7 +64,7 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
self.finished_req_ids_dict: dict[int, set[str]] | None = (
defaultdict(set) if include_finished_set else None
)
......@@ -708,7 +708,7 @@ class Scheduler(SchedulerInterface):
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[Optional[tuple[list[int], ...]]] = []
new_block_ids: list[tuple[list[int], ...] | None] = []
num_computed_tokens: list[int] = []
num_output_tokens: list[int] = []
......@@ -921,7 +921,7 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
......@@ -1212,7 +1212,7 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
def _free_request(self, request: Request) -> dict[str, Any] | None:
assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
......@@ -1243,9 +1243,9 @@ class Scheduler(SchedulerInterface):
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
kv_connector_stats: Optional[KVConnectorStats] = None,
) -> Optional[SchedulerStats]:
spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: KVConnectorStats | None = None,
) -> SchedulerStats | None:
if not self.log_stats:
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
......@@ -1262,10 +1262,10 @@ class Scheduler(SchedulerInterface):
def make_spec_decoding_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats],
spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int,
num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]:
) -> SpecDecodingStats | None:
if not self.log_stats:
return None
if spec_decoding_stats is None:
......@@ -1285,12 +1285,12 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods
########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector
def _connector_finished(
self, request: Request
) -> tuple[bool, Optional[dict[str, Any]]]:
) -> tuple[bool, dict[str, Any] | None]:
"""
Invoke the KV connector request_finished() method if applicable.
......
......@@ -4,7 +4,7 @@ from __future__ import annotations
import multiprocessing
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.logger import init_logger
......@@ -35,11 +35,11 @@ class StructuredOutputManager:
"""Engine-level manager for structured output requests."""
def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.reasoner: Optional[ReasoningParser] = None
self.backend: StructuredOutputBackend | None = None
self.reasoner: ReasoningParser | None = None
self.vllm_config = vllm_config
self._grammar_bitmask: Optional[torch.Tensor] = None
self._grammar_bitmask: torch.Tensor | None = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
......@@ -168,7 +168,7 @@ class StructuredOutputManager:
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]:
) -> npt.NDArray[np.int32] | None:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
return None
......
......@@ -7,7 +7,7 @@ import copy
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
import torch
......@@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: Optional[llguidance.LLTokenizer] = None
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm)
......
......@@ -20,10 +20,10 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
_grammar: Optional[
Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]
] = None
reasoning_ended: Optional[bool] = None
_grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = (
None
)
reasoning_ended: bool | None = None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
......@@ -43,7 +43,7 @@ class StructuredOutputRequest:
return self._check_grammar_completion()
@property
def grammar(self) -> Optional[StructuredOutputGrammar]:
def grammar(self) -> StructuredOutputGrammar | None:
completed = self._check_grammar_completion()
return (
cast(Optional[StructuredOutputGrammar], self._grammar)
......
......@@ -4,7 +4,7 @@
from __future__ import annotations
import os
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, TypeVar, Union
import torch
import torch.nn as nn
......@@ -78,8 +78,8 @@ class WorkerBase:
self.is_driver_worker = is_driver_worker
# Device and model state
self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None
self.device: torch.device | None = None
self.model_runner: nn.Module | None = None
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
......@@ -115,8 +115,8 @@ class WorkerBase:
raise NotImplementedError
def execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[list[SamplerOutput]]:
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput] | None:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
......@@ -198,8 +198,8 @@ class WorkerWrapperBase:
group.
"""
self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None
self.vllm_config: Optional[VllmConfig] = None
self.worker: WorkerBase | None = None
self.vllm_config: VllmConfig | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment