"vscode:/vscode.git/clone" did not exist on "0cb7b065c3c2c3e9fa269f48fb4e945616f7f5b9"
Unverified Commit 1a7894db authored by 杨朱 · Kiki's avatar 杨朱 · Kiki Committed by GitHub
Browse files

[Misc] Replace Optional[X] with X | None syntax (#33332)


Signed-off-by: default avatarcarlory <baofa.fan@daocloud.io>
Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent c87eac18
...@@ -7,7 +7,7 @@ import platform ...@@ -7,7 +7,7 @@ import platform
import random import random
import sys import sys
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple, Optional from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np import numpy as np
import torch import torch
...@@ -243,7 +243,7 @@ class Platform: ...@@ -243,7 +243,7 @@ class Platform:
cls, cls,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None, backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
""" """
Get the vision attention backend class of a device. Get the vision attention backend class of a device.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import os import os
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import torch import torch
...@@ -356,7 +356,7 @@ class RocmPlatform(Platform): ...@@ -356,7 +356,7 @@ class RocmPlatform(Platform):
cls, cls,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None, backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
if backend is not None: if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), ( assert backend in cls.get_supported_vit_attn_backends(), (
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import contextlib import contextlib
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import torch import torch
...@@ -88,7 +88,7 @@ class XPUPlatform(Platform): ...@@ -88,7 +88,7 @@ class XPUPlatform(Platform):
cls, cls,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None, backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
if backend is not None: if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), ( assert backend in cls.get_supported_vit_attn_backends(), (
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy from copy import deepcopy
from typing import Annotated, Any, Optional from typing import Annotated, Any
import msgspec import msgspec
...@@ -80,7 +80,7 @@ class PoolingParams( ...@@ -80,7 +80,7 @@ class PoolingParams(
return deepcopy(self) return deepcopy(self)
def verify( def verify(
self, task: PoolingTask, model_config: Optional["ModelConfig"] = None self, task: PoolingTask, model_config: "ModelConfig | None" = None
) -> None: ) -> None:
if self.task is None: if self.task is None:
self.task = task self.task = task
...@@ -106,7 +106,7 @@ class PoolingParams( ...@@ -106,7 +106,7 @@ class PoolingParams(
self._verify_valid_parameters() self._verify_valid_parameters()
def _merge_default_parameters( def _merge_default_parameters(
self, model_config: Optional["ModelConfig"] = None self, model_config: "ModelConfig | None" = None
) -> None: ) -> None:
if model_config is None: if model_config is None:
return return
...@@ -160,7 +160,7 @@ class PoolingParams( ...@@ -160,7 +160,7 @@ class PoolingParams(
if getattr(self, k, None) is None: if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k)) setattr(self, k, getattr(pooler_config, k))
def _set_default_parameters(self, model_config: Optional["ModelConfig"]): def _set_default_parameters(self, model_config: "ModelConfig | None"):
if self.task in ["embed", "token_embed"]: if self.task in ["embed", "token_embed"]:
if self.use_activation is None: if self.use_activation is None:
self.use_activation = True self.use_activation = True
......
...@@ -5,7 +5,7 @@ import copy ...@@ -5,7 +5,7 @@ import copy
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Optional, TypeAlias from typing import Any, TypeAlias
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
...@@ -31,7 +31,7 @@ except ImportError: ...@@ -31,7 +31,7 @@ except ImportError:
@dataclass @dataclass
class _ModuleTreeNode: class _ModuleTreeNode:
event: _ProfilerEvent event: _ProfilerEvent
parent: Optional["_ModuleTreeNode"] = None parent: "_ModuleTreeNode | None" = None
children: list["_ModuleTreeNode"] = field(default_factory=list) children: list["_ModuleTreeNode"] = field(default_factory=list)
trace: str = "" trace: str = ""
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch import fnmatch
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
...@@ -32,7 +32,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: ...@@ -32,7 +32,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
def glob( def glob(
s3: Optional["BaseClient"] = None, s3: "BaseClient | None" = None,
path: str = "", path: str = "",
allow_pattern: list[str] | None = None, allow_pattern: list[str] | None = None,
) -> list[str]: ) -> list[str]:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional from typing import TYPE_CHECKING, ClassVar
import numpy as np import numpy as np
import torch import torch
...@@ -707,7 +707,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -707,7 +707,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_sharing_target_layer_name: str | None, kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments # MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None, topk_indice_buffer: torch.Tensor | None = None,
indexer: Optional["Indexer"] = None, indexer: "Indexer | None" = None,
**mla_args, **mla_args,
) -> None: ) -> None:
super().__init__( super().__init__(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional from typing import TYPE_CHECKING, ClassVar
import numpy as np import numpy as np
import torch import torch
...@@ -284,7 +284,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -284,7 +284,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
kv_sharing_target_layer_name: str | None, kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments # MLA Specific Arguments
topk_indice_buffer: torch.Tensor | None = None, topk_indice_buffer: torch.Tensor | None = None,
indexer: Optional["Indexer"] = None, indexer: "Indexer | None" = None,
**mla_args, **mla_args,
) -> None: ) -> None:
super().__init__( super().__init__(
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import ClassVar
import torch import torch
...@@ -87,11 +87,11 @@ class TreeAttentionMetadata: ...@@ -87,11 +87,11 @@ class TreeAttentionMetadata:
tree_attn_bias: torch.Tensor | None = None tree_attn_bias: torch.Tensor | None = None
# Cached Prefill/decode metadata. # Cached Prefill/decode metadata.
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None _cached_prefill_metadata: "TreeAttentionMetadata | None" = None
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None _cached_decode_metadata: "TreeAttentionMetadata | None" = None
@property @property
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: def prefill_metadata(self) -> "TreeAttentionMetadata | None":
if self.num_prefills == 0: if self.num_prefills == 0:
return None return None
...@@ -116,7 +116,7 @@ class TreeAttentionMetadata: ...@@ -116,7 +116,7 @@ class TreeAttentionMetadata:
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: def decode_metadata(self) -> "TreeAttentionMetadata | None":
if self.num_decode_tokens == 0: if self.num_decode_tokens == 0:
return None return None
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -189,7 +189,7 @@ class SchedulerInterface(ABC): ...@@ -189,7 +189,7 @@ class SchedulerInterface(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]: def make_stats(self) -> "SchedulerStats | None":
"""Make a SchedulerStats object for logging. """Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step. The SchedulerStats object is created for every scheduling step.
...@@ -201,5 +201,5 @@ class SchedulerInterface(ABC): ...@@ -201,5 +201,5 @@ class SchedulerInterface(ABC):
"""Shutdown the scheduler.""" """Shutdown the scheduler."""
raise NotImplementedError raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]: def get_kv_connector(self) -> "KVConnectorBase_V1 | None":
return None return None
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy from copy import copy
from typing import Optional, cast from typing import cast
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
...@@ -133,7 +133,7 @@ class ParentRequest: ...@@ -133,7 +133,7 @@ class ParentRequest:
@staticmethod @staticmethod
def observe_finished_request( def observe_finished_request(
parent_req: Optional["ParentRequest"], parent_req: "ParentRequest | None",
iteration_stats: IterationStats, iteration_stats: IterationStats,
num_generation_tokens: int, num_generation_tokens: int,
): ):
......
...@@ -7,7 +7,7 @@ from collections import deque ...@@ -7,7 +7,7 @@ from collections import deque
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
...@@ -68,7 +68,7 @@ class Request: ...@@ -68,7 +68,7 @@ class Request:
arrival_time: float | None = None, arrival_time: float | None = None,
prompt_embeds: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None,
mm_features: list[MultiModalFeatureSpec] | None = None, mm_features: list[MultiModalFeatureSpec] | None = None,
lora_request: Optional["LoRARequest"] = None, lora_request: "LoRARequest | None" = None,
cache_salt: str | None = None, cache_salt: str | None = None,
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
......
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import torch import torch
...@@ -94,7 +94,7 @@ class LogitsProcessor(ABC): ...@@ -94,7 +94,7 @@ class LogitsProcessor(ABC):
@abstractmethod @abstractmethod
def update_state( def update_state(
self, self,
batch_update: Optional["BatchUpdate"], batch_update: "BatchUpdate | None",
) -> None: ) -> None:
"""Called when there are new output tokens, prior """Called when there are new output tokens, prior
to each forward pass. to each forward pass.
......
...@@ -14,7 +14,6 @@ from typing import ( ...@@ -14,7 +14,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Generic, Generic,
Optional,
TypeVar, TypeVar,
Union, Union,
overload, overload,
...@@ -229,7 +228,7 @@ def wait_for_completion_or_failure( ...@@ -229,7 +228,7 @@ def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager, api_server_manager: APIServerProcessManager,
engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"] engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
| None = None, | None = None,
coordinator: Optional["DPCoordinator"] = None, coordinator: "DPCoordinator | None" = None,
) -> None: ) -> None:
"""Wait for all processes to complete or detect if any fail. """Wait for all processes to complete or detect if any fail.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Optional
import torch import torch
...@@ -15,7 +14,7 @@ logger = init_logger(__name__) ...@@ -15,7 +14,7 @@ logger = init_logger(__name__)
_THREAD_ID_TO_CONTEXT: dict = {} _THREAD_ID_TO_CONTEXT: dict = {}
# Here we hardcode the number of microbatches to 2 for default. # Here we hardcode the number of microbatches to 2 for default.
_NUM_UBATCHES: int = 2 _NUM_UBATCHES: int = 2
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [] _CURRENT_CONTEXTS: list["UBatchContext | None"] = []
class UBatchContext: class UBatchContext:
......
...@@ -5,7 +5,6 @@ import inspect ...@@ -5,7 +5,6 @@ import inspect
import os import os
from itertools import accumulate from itertools import accumulate
from math import prod from math import prod
from typing import Optional
import torch import torch
...@@ -26,7 +25,7 @@ _MB = 1024**2 ...@@ -26,7 +25,7 @@ _MB = 1024**2
_GiB = 1024**3 _GiB = 1024**3
# Global workspace manager instance # Global workspace manager instance
_manager: Optional["WorkspaceManager"] = None _manager: "WorkspaceManager | None" = None
class WorkspaceManager: class WorkspaceManager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment