Unverified Commit 1a7894db authored by 杨朱 · Kiki's avatar 杨朱 · Kiki Committed by GitHub
Browse files

[Misc] Replace Optional[X] with X | None syntax (#33332)


Signed-off-by: default avatarcarlory <baofa.fan@daocloud.io>
Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent c87eac18
......@@ -3,7 +3,7 @@
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
import torch
from compressed_tensors.config import (
......@@ -160,7 +160,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
......@@ -691,7 +691,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_scheme(
self, layer: torch.nn.Module, layer_name: str | None = None
) -> Optional["CompressedTensorsScheme"]:
) -> "CompressedTensorsScheme | None":
"""
compressed-tensors supports non uniform in the following way:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
......@@ -105,7 +105,7 @@ class CPUAWQConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]:
) -> "QuantizationMethods | None":
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if current_platform.is_cpu() and (quant_method == "awq"):
return cls.get_name()
......@@ -113,7 +113,7 @@ class CPUAWQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
......@@ -52,7 +52,7 @@ class ExpertsInt8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
from torch.nn import Module
......@@ -78,7 +78,7 @@ class FBGEMMFp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
from torch.nn import Module
......@@ -182,7 +182,7 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod,
XPUFp8MoEMethod,
......@@ -218,7 +218,7 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):
......
......@@ -3,7 +3,7 @@
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional
from typing import Any
import gguf
import torch
......@@ -77,7 +77,7 @@ class GGUFConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(
prefix, self.unquantized_modules, self.packed_modules_mapping
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import Any, Optional
from typing import Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
......@@ -240,7 +240,7 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import regex as re
import torch
......@@ -456,7 +456,7 @@ class INCConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]:
) -> "QuantizationMethods | None":
"""Override the `auto-round` method to `inc`."""
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
if is_auto_round_format:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
from packaging import version
......@@ -144,7 +144,7 @@ class IPEXConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["LinearMethodBase"]:
) -> "LinearMethodBase | None":
if isinstance(layer, LinearBase):
if self.method == "awq":
if is_layer_skipped(
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
from torch.nn import Module
......@@ -181,7 +181,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
......@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
if isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Optional
import torch
from torch.nn.parameter import Parameter
......@@ -197,7 +196,7 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from typing import Any, Optional
from typing import Any
import regex as re
import torch
......@@ -159,7 +159,7 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
from torch.nn.parameter import Parameter
......@@ -67,7 +67,7 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
import torch
......@@ -102,7 +102,7 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(
......
......@@ -4,7 +4,7 @@ import importlib
import json
import types
from importlib.util import find_spec
from typing import Any, Optional
from typing import Any
import regex as re
import torch
......@@ -209,7 +209,7 @@ class TorchAOConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if not isinstance(layer, LinearBase):
return None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import torch
......@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from types import ModuleType
# 1. Create a global variable as a placeholder for the module
_petit_kernel: Optional["ModuleType"] = None
_petit_kernel: "ModuleType | None" = None
_PETIT_INSTALL_MSG = (
"Petit is not installed. Please install it with `pip install petit-kernel`."
......
......@@ -12,7 +12,7 @@ import threading
import time
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from typing import TYPE_CHECKING, Any, ClassVar
import regex as re
import torch
......@@ -323,7 +323,7 @@ class TensorizerConfig(MutableMapping):
" is unstable and may lead to errors."
)
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
def open_stream(self, tensorizer_args: "TensorizerArgs | None" = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
......
......@@ -11,7 +11,6 @@ from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
TypeAlias,
TypedDict,
Union,
......@@ -186,7 +185,7 @@ class PlaceholderRange:
length: int
"""The length of the placeholder."""
is_embed: Optional["torch.Tensor"] = None
is_embed: "torch.Tensor | None" = None
"""
A boolean mask of shape `(length,)` indicating which positions
between `offset` and `offset + length` to assign embeddings to.
......@@ -341,7 +340,7 @@ class MultiModalFeatureSpec:
`MultiModalFeatureSpec` per item.
"""
data: Optional["MultiModalKwargsItem"]
data: "MultiModalKwargsItem | None"
"""
Represents multimodal data for this feature.
......
......@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from functools import cache, wraps
from typing import TYPE_CHECKING, Optional, TypeVar
from typing import TYPE_CHECKING, TypeVar
import torch
from typing_extensions import ParamSpec
......@@ -382,7 +382,7 @@ class CudaPlatformBase(Platform):
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment