"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "c721b814e31d1646ce95bca2acf68fe285fdc34e"
Unverified commit 1a7894db, authored by 杨朱 · Kiki, committed by GitHub
Browse files

[Misc] Replace Optional[X] with X | None syntax (#33332)


Signed-off-by: carlory <baofa.fan@daocloud.io>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
parent c87eac18
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -27,7 +27,7 @@ class BeamSearchSequence: ...@@ -27,7 +27,7 @@ class BeamSearchSequence:
text: str | None = None text: str | None = None
finish_reason: str | None = None finish_reason: str | None = None
stop_reason: int | str | None = None stop_reason: int | str | None = None
multi_modal_data: Optional["MultiModalDataDict"] = None multi_modal_data: "MultiModalDataDict | None" = None
mm_processor_kwargs: dict[str, Any] | None = None mm_processor_kwargs: dict[str, Any] | None = None
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import importlib import importlib
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, cast
from vllm.distributed.kv_transfer.kv_connector.base import ( from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBase,
...@@ -44,7 +44,7 @@ class KVConnectorFactory: ...@@ -44,7 +44,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
) -> KVConnectorBase: ) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None: if kv_transfer_config is None:
......
...@@ -41,7 +41,7 @@ The class provides the following primitives: ...@@ -41,7 +41,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal
import torch import torch
...@@ -161,7 +161,7 @@ class KVConnectorBase_V1(ABC): ...@@ -161,7 +161,7 @@ class KVConnectorBase_V1(ABC):
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
logger.warning( logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and " "Initializing KVConnectorBase_V1. This API is experimental and "
...@@ -383,13 +383,13 @@ class KVConnectorBase_V1(ABC): ...@@ -383,13 +383,13 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: def get_kv_connector_stats(self) -> "KVConnectorStats | None":
""" """
Get the KV connector stats collected during the last interval. Get the KV connector stats collected during the last interval.
""" """
return None return None
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]: def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
""" """
Get the KV connector kv cache events collected during the last interval. Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the This function should be called by the model runner every time after the
...@@ -558,7 +558,7 @@ class KVConnectorBase_V1(ABC): ...@@ -558,7 +558,7 @@ class KVConnectorBase_V1(ABC):
@classmethod @classmethod
def build_kv_connector_stats( def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]: ) -> "KVConnectorStats | None":
""" """
KVConnectorStats resolution method. This method allows dynamically KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object, registered connectors to return their own KVConnectorStats object,
...@@ -584,7 +584,7 @@ class KVConnectorBase_V1(ABC): ...@@ -584,7 +584,7 @@ class KVConnectorBase_V1(ABC):
metric_types: dict[type["PromMetric"], type["PromMetricT"]], metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str], labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]], per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]: ) -> "KVConnectorPromMetrics | None":
""" """
Create a KVConnectorPromMetrics subclass which should register Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to per-connector Prometheus metrics and implement observe() to
......
...@@ -32,7 +32,7 @@ Usage: ...@@ -32,7 +32,7 @@ Usage:
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
...@@ -84,7 +84,7 @@ class DecodeBenchConnector(KVConnectorBase_V1): ...@@ -84,7 +84,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__(vllm_config, role, kv_cache_config) super().__init__(vllm_config, role, kv_cache_config)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import safetensors import safetensors
import torch import torch
...@@ -91,7 +91,7 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -91,7 +91,7 @@ class ExampleConnector(KVConnectorBase_V1):
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__( super().__init__(
vllm_config=vllm_config, vllm_config=vllm_config,
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
from lmcache import utils from lmcache import utils
...@@ -274,7 +274,7 @@ class ReqMeta: ...@@ -274,7 +274,7 @@ class ReqMeta:
load_spec: LoadSpec | None = None, load_spec: LoadSpec | None = None,
discard_partial_chunks: bool = True, discard_partial_chunks: bool = True,
save_decode_cache: bool = False, save_decode_cache: bool = False,
) -> Optional["ReqMeta"]: ) -> "ReqMeta | None":
"""Create the request metadata from a request tracker. """Create the request metadata from a request tracker.
Args: Args:
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import enum import enum
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional, cast from typing import TYPE_CHECKING, Any, Literal, cast
import torch import torch
import zmq import zmq
...@@ -385,7 +385,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -385,7 +385,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__(vllm_config, role, kv_cache_config) super().__init__(vllm_config, role, kv_cache_config)
...@@ -595,7 +595,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -595,7 +595,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self.worker_adapter.shutdown() self.worker_adapter.shutdown()
return None return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: def get_kv_connector_stats(self) -> "KVConnectorStats | None":
""" """
Get the KV connector stats collected during the last interval. Get the KV connector stats collected during the last interval.
""" """
...@@ -810,7 +810,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -810,7 +810,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
@classmethod @classmethod
def build_kv_connector_stats( def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]: ) -> "KVConnectorStats | None":
""" """
KVConnectorStats resolution method. This method allows dynamically KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object, registered connectors to return their own KVConnectorStats object,
...@@ -825,7 +825,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -825,7 +825,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metric_types: dict[type["PromMetric"], type["PromMetricT"]], metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str], labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]], per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]: ) -> "KVConnectorPromMetrics | None":
""" """
Create a KVConnectorPromMetrics subclass which should register Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to per-connector Prometheus metrics and implement observe() to
......
...@@ -6,7 +6,7 @@ import time ...@@ -6,7 +6,7 @@ import time
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import msgspec import msgspec
import numpy as np import numpy as np
...@@ -115,7 +115,7 @@ class MooncakeConnector(KVConnectorBase_V1): ...@@ -115,7 +115,7 @@ class MooncakeConnector(KVConnectorBase_V1):
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__(vllm_config, role, kv_cache_config) super().__init__(vllm_config, role, kv_cache_config)
......
...@@ -5,7 +5,7 @@ import threading ...@@ -5,7 +5,7 @@ import threading
import time import time
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import msgspec import msgspec
import torch import torch
...@@ -101,7 +101,7 @@ class MoRIIOAgentMetadata( ...@@ -101,7 +101,7 @@ class MoRIIOAgentMetadata(
class RoleManager: class RoleManager:
"""Manages role state across the connector.""" """Manages role state across the connector."""
_instance: Optional["RoleManager"] = None _instance: "RoleManager | None" = None
_lock = threading.Lock() _lock = threading.Lock()
def __init__(self) -> None: def __init__(self) -> None:
......
...@@ -7,7 +7,7 @@ import threading ...@@ -7,7 +7,7 @@ import threading
import time import time
from collections import defaultdict from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import msgpack import msgpack
import msgspec import msgspec
...@@ -90,7 +90,7 @@ class MoRIIOConnector(KVConnectorBase_V1): ...@@ -90,7 +90,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__(vllm_config, role) super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None, ( assert vllm_config.kv_transfer_config is not None, (
...@@ -333,7 +333,7 @@ class MoRIIOConnectorScheduler: ...@@ -333,7 +333,7 @@ class MoRIIOConnectorScheduler:
request: "Request", request: "Request",
blocks: "KVCacheBlocks", blocks: "KVCacheBlocks",
num_external_tokens: int, num_external_tokens: int,
connector_worker: Optional["MoRIIOConnectorWorker"] = None, connector_worker: "MoRIIOConnectorWorker | None" = None,
): ):
params = request.kv_transfer_params params = request.kv_transfer_params
if not params: if not params:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
from weakref import ref as weakref_ref from weakref import ref as weakref_ref
import msgpack import msgpack
...@@ -340,7 +340,7 @@ class MoRIIOWrapper: ...@@ -340,7 +340,7 @@ class MoRIIOWrapper:
def __init__( def __init__(
self, self,
moriio_engine: Optional["IOEngine"] = None, moriio_engine: "IOEngine | None" = None,
tp_rank: int = 0, tp_rank: int = 0,
dp_rank: int = 0, dp_rank: int = 0,
): ):
......
...@@ -14,7 +14,7 @@ from collections import defaultdict ...@@ -14,7 +14,7 @@ from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import msgspec import msgspec
import numpy as np import numpy as np
...@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__(vllm_config, role, kv_cache_config) super().__init__(vllm_config, role, kv_cache_config)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import regex as re import regex as re
import torch import torch
...@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: "KVCacheConfig | None" = None,
): ):
super().__init__( super().__init__(
vllm_config=vllm_config, vllm_config=vllm_config,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
...@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo ...@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
def ensure_kv_transfer_initialized( def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
) -> None: ) -> None:
""" """
Initialize KV cache transfer parallel group. Initialize KV cache transfer parallel group.
......
...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext ...@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Optional from typing import Any
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -106,7 +106,7 @@ def _get_unique_name(name: str) -> str: ...@@ -106,7 +106,7 @@ def _get_unique_name(name: str) -> str:
return newname return newname
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} _groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}
def _register_group(group: "GroupCoordinator") -> None: def _register_group(group: "GroupCoordinator") -> None:
...@@ -784,7 +784,7 @@ class GroupCoordinator: ...@@ -784,7 +784,7 @@ class GroupCoordinator:
self, self,
tensor_dict: dict[str, torch.Tensor | Any], tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None, dst: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None, all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None, all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None: ) -> dict[str, torch.Tensor | Any] | None:
"""Send the input tensor dictionary. """Send the input tensor dictionary.
...@@ -871,7 +871,7 @@ class GroupCoordinator: ...@@ -871,7 +871,7 @@ class GroupCoordinator:
def recv_tensor_dict( def recv_tensor_dict(
self, self,
src: int | None = None, src: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None, all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None, all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None: ) -> dict[str, torch.Tensor | Any] | None:
"""Recv the input tensor dictionary. """Recv the input tensor dictionary.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Pydantic models for Anthropic API protocol""" """Pydantic models for Anthropic API protocol"""
import time import time
from typing import Any, Literal, Optional from typing import Any, Literal
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
...@@ -135,7 +135,7 @@ class AnthropicStreamEvent(BaseModel): ...@@ -135,7 +135,7 @@ class AnthropicStreamEvent(BaseModel):
"ping", "ping",
"error", "error",
] ]
message: Optional["AnthropicMessagesResponse"] = None message: "AnthropicMessagesResponse | None" = None
delta: AnthropicDelta | None = None delta: AnthropicDelta | None = None
content_block: AnthropicContentBlock | None = None content_block: AnthropicContentBlock | None = None
index: int | None = None index: int | None = None
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import Optional
import torch import torch
import torch.types import torch.types
...@@ -126,7 +125,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -126,7 +125,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod @classmethod
def pack( def pack(
cls, loras: GenericSequence[Optional["LoRALayerWeights"]] cls, loras: GenericSequence["LoRALayerWeights | None"]
) -> "PackedLoRALayerWeights": ) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
...@@ -155,7 +154,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -155,7 +154,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod @classmethod
def pack_moe( def pack_moe(
cls, cls,
loras: GenericSequence[Optional["LoRALayerWeights"]], loras: GenericSequence["LoRALayerWeights | None"],
module_name: str, module_name: str,
is_non_gated_moe: bool = False, is_non_gated_moe: bool = False,
) -> "PackedLoRALayerWeights": ) -> "PackedLoRALayerWeights":
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import huggingface_hub import huggingface_hub
from huggingface_hub.utils import HfHubHTTPError, HFValidationError from huggingface_hub.utils import HfHubHTTPError, HFValidationError
...@@ -131,7 +131,7 @@ def replace_submodule( ...@@ -131,7 +131,7 @@ def replace_submodule(
def parse_fine_tuned_lora_name( def parse_fine_tuned_lora_name(
name: str, weights_mapper: Optional["WeightsMapper"] = None name: str, weights_mapper: "WeightsMapper | None" = None
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Parse the name of lora weights. """Parse the name of lora weights.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from typing import Optional, Union from typing import Union
import torch import torch
...@@ -284,7 +284,7 @@ class FusedMoEQuantConfig: ...@@ -284,7 +284,7 @@ class FusedMoEQuantConfig:
return self._w1.bias return self._w1.bias
@property @property
def w1_precision(self) -> Optional["PrecisionConfig"]: def w1_precision(self) -> "PrecisionConfig | None":
assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig) assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig)
return self._w1.scale return self._w1.scale
...@@ -306,7 +306,7 @@ class FusedMoEQuantConfig: ...@@ -306,7 +306,7 @@ class FusedMoEQuantConfig:
return self._w2.bias return self._w2.bias
@property @property
def w2_precision(self) -> Optional["PrecisionConfig"]: def w2_precision(self) -> "PrecisionConfig | None":
assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig) assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig)
return self._w2.scale return self._w2.scale
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import torch import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
...@@ -148,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -148,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]: ) -> "QuantizationMethods | None":
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = ( is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
...@@ -173,7 +173,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -173,7 +173,7 @@ class AWQMarlinConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase) or ( if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized isinstance(layer, ParallelLMHead) and self.lm_head_quantized
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment