Unverified Commit 1a7894db authored by 杨朱 · Kiki's avatar 杨朱 · Kiki Committed by GitHub
Browse files

[Misc] Replace Optional[X] with X | None syntax (#33332)


Signed-off-by: default avatarcarlory <baofa.fan@daocloud.io>
Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent c87eac18
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Optional, cast from typing import TYPE_CHECKING, Any, Literal, cast
import torch import torch
from compressed_tensors.config import ( from compressed_tensors.config import (
...@@ -160,7 +160,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -160,7 +160,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
# collect schemes # collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
...@@ -691,7 +691,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -691,7 +691,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_scheme( def get_scheme(
self, layer: torch.nn.Module, layer_name: str | None = None self, layer: torch.nn.Module, layer_name: str | None = None
) -> Optional["CompressedTensorsScheme"]: ) -> "CompressedTensorsScheme | None":
""" """
compressed-tensors supports non uniform in the following way: compressed-tensors supports non uniform in the following way:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
...@@ -105,7 +105,7 @@ class CPUAWQConfig(QuantizationConfig): ...@@ -105,7 +105,7 @@ class CPUAWQConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]: ) -> "QuantizationMethods | None":
quant_method = hf_quant_cfg.get("quant_method", "").lower() quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "awq"): if current_platform.is_cpu() and (quant_method == "awq"):
return cls.get_name() return cls.get_name()
...@@ -113,7 +113,7 @@ class CPUAWQConfig(QuantizationConfig): ...@@ -113,7 +113,7 @@ class CPUAWQConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase) or ( if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized isinstance(layer, ParallelLMHead) and self.lm_head_quantized
): ):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
...@@ -52,7 +52,7 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -52,7 +52,7 @@ class ExpertsInt8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -78,7 +78,7 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -78,7 +78,7 @@ class FBGEMMFp8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped( if is_layer_skipped(
prefix=prefix, prefix=prefix,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -182,7 +182,7 @@ class Fp8Config(QuantizationConfig): ...@@ -182,7 +182,7 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method( def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
from vllm.model_executor.layers.quantization.ipex_quant import ( from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod, XPUFp8LinearMethod,
XPUFp8MoEMethod, XPUFp8MoEMethod,
...@@ -218,7 +218,7 @@ class Fp8Config(QuantizationConfig): ...@@ -218,7 +218,7 @@ class Fp8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if current_platform.is_xpu(): if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix) return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Optional from typing import Any
import gguf import gguf
import torch import torch
...@@ -77,7 +77,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -77,7 +77,7 @@ class GGUFConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped_gguf( if is_layer_skipped_gguf(
prefix, self.unquantized_modules, self.packed_modules_mapping prefix, self.unquantized_modules, self.packed_modules_mapping
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional from typing import Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
...@@ -240,7 +240,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -240,7 +240,7 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fractions import Fraction from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import regex as re import regex as re
import torch import torch
...@@ -456,7 +456,7 @@ class INCConfig(QuantizationConfig): ...@@ -456,7 +456,7 @@ class INCConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]: ) -> "QuantizationMethods | None":
"""Override the `auto-round` method to `inc`.""" """Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round" is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
if is_auto_round_format: if is_auto_round_format:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
from packaging import version from packaging import version
...@@ -144,7 +144,7 @@ class IPEXConfig(QuantizationConfig): ...@@ -144,7 +144,7 @@ class IPEXConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["LinearMethodBase"]: ) -> "LinearMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if self.method == "awq": if self.method == "awq":
if is_layer_skipped( if is_layer_skipped(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fnmatch import fnmatch from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -181,7 +181,7 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -181,7 +181,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
# handle kv-cache first so we can focus only on weight quantization thereafter # handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention): if isinstance(layer, Attention):
return self.KVCacheMethodCls(self) return self.KVCacheMethodCls(self)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
...@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if is_layer_skipped_quant(prefix, self.modules_to_not_convert): if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedFusedMoEMethod(layer.moe_config)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum from enum import Enum
from typing import Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -197,7 +196,7 @@ class Mxfp4Config(QuantizationConfig): ...@@ -197,7 +196,7 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped( if self.ignored_layers and is_layer_skipped(
prefix=prefix, prefix=prefix,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from typing import Any, Optional from typing import Any
import regex as re import regex as re
import torch import torch
...@@ -159,7 +159,7 @@ class PetitNvFp4Config(QuantizationConfig): ...@@ -159,7 +159,7 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
exclude = self.require_exclude_modules() exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -67,7 +67,7 @@ class PTPCFp8Config(Fp8Config): ...@@ -67,7 +67,7 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, cast
import torch import torch
...@@ -102,7 +102,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -102,7 +102,7 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude")) exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer( if should_ignore_layer(
......
...@@ -4,7 +4,7 @@ import importlib ...@@ -4,7 +4,7 @@ import importlib
import json import json
import types import types
from importlib.util import find_spec from importlib.util import find_spec
from typing import Any, Optional from typing import Any
import regex as re import regex as re
import torch import torch
...@@ -209,7 +209,7 @@ class TorchAOConfig(QuantizationConfig): ...@@ -209,7 +209,7 @@ class TorchAOConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if not isinstance(layer, LinearBase): if not isinstance(layer, LinearBase):
return None return None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import torch import torch
...@@ -9,7 +9,7 @@ if TYPE_CHECKING: ...@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from types import ModuleType from types import ModuleType
# 1. Create a global variable as a placeholder for the module # 1. Create a global variable as a placeholder for the module
_petit_kernel: Optional["ModuleType"] = None _petit_kernel: "ModuleType | None" = None
_PETIT_INSTALL_MSG = ( _PETIT_INSTALL_MSG = (
"Petit is not installed. Please install it with `pip install petit-kernel`." "Petit is not installed. Please install it with `pip install petit-kernel`."
......
...@@ -12,7 +12,7 @@ import threading ...@@ -12,7 +12,7 @@ import threading
import time import time
from collections.abc import Generator, MutableMapping from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar, Optional from typing import TYPE_CHECKING, Any, ClassVar
import regex as re import regex as re
import torch import torch
...@@ -323,7 +323,7 @@ class TensorizerConfig(MutableMapping): ...@@ -323,7 +323,7 @@ class TensorizerConfig(MutableMapping):
" is unstable and may lead to errors." " is unstable and may lead to errors."
) )
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): def open_stream(self, tensorizer_args: "TensorizerArgs | None" = None):
if tensorizer_args is None: if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args() tensorizer_args = self._construct_tensorizer_args()
......
...@@ -11,7 +11,6 @@ from typing import ( ...@@ -11,7 +11,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Literal, Literal,
Optional,
TypeAlias, TypeAlias,
TypedDict, TypedDict,
Union, Union,
...@@ -186,7 +185,7 @@ class PlaceholderRange: ...@@ -186,7 +185,7 @@ class PlaceholderRange:
length: int length: int
"""The length of the placeholder.""" """The length of the placeholder."""
is_embed: Optional["torch.Tensor"] = None is_embed: "torch.Tensor | None" = None
""" """
A boolean mask of shape `(length,)` indicating which positions A boolean mask of shape `(length,)` indicating which positions
between `offset` and `offset + length` to assign embeddings to. between `offset` and `offset + length` to assign embeddings to.
...@@ -341,7 +340,7 @@ class MultiModalFeatureSpec: ...@@ -341,7 +340,7 @@ class MultiModalFeatureSpec:
`MultiModalFeatureSpec` per item. `MultiModalFeatureSpec` per item.
""" """
data: Optional["MultiModalKwargsItem"] data: "MultiModalKwargsItem | None"
""" """
Represents multimodal data for this feature. Represents multimodal data for this feature.
......
...@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context. ...@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os import os
from collections.abc import Callable from collections.abc import Callable
from functools import cache, wraps from functools import cache, wraps
from typing import TYPE_CHECKING, Optional, TypeVar from typing import TYPE_CHECKING, TypeVar
import torch import torch
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
...@@ -382,7 +382,7 @@ class CudaPlatformBase(Platform): ...@@ -382,7 +382,7 @@ class CudaPlatformBase(Platform):
cls, cls,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None, backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
if backend is not None: if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), ( assert backend in cls.get_supported_vit_attn_backends(), (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment