Unverified commit 1a7894db, authored by 杨朱 · Kiki, committed by GitHub
Browse files

[Misc] Replace Optional[X] with X | None syntax (#33332)


Signed-off-by: carlory <baofa.fan@daocloud.io>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
parent c87eac18
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
......@@ -27,7 +27,7 @@ class BeamSearchSequence:
text: str | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
multi_modal_data: Optional["MultiModalDataDict"] = None
multi_modal_data: "MultiModalDataDict | None" = None
mm_processor_kwargs: dict[str, Any] | None = None
......
......@@ -3,7 +3,7 @@
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, cast
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
......@@ -44,7 +44,7 @@ class KVConnectorFactory:
cls,
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
......
......@@ -41,7 +41,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal
import torch
......@@ -161,7 +161,7 @@ class KVConnectorBase_V1(ABC):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
......@@ -383,13 +383,13 @@ class KVConnectorBase_V1(ABC):
"""
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
"""
Get the KV connector stats collected during the last interval.
"""
return None
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
......@@ -558,7 +558,7 @@ class KVConnectorBase_V1(ABC):
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
) -> "KVConnectorStats | None":
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
......@@ -584,7 +584,7 @@ class KVConnectorBase_V1(ABC):
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]:
) -> "KVConnectorPromMetrics | None":
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
......
......@@ -32,7 +32,7 @@ Usage:
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
......@@ -84,7 +84,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role, kv_cache_config)
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import safetensors
import torch
......@@ -91,7 +91,7 @@ class ExampleConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(
vllm_config=vllm_config,
......
......@@ -5,7 +5,7 @@ import os
import uuid
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
from lmcache import utils
......@@ -274,7 +274,7 @@ class ReqMeta:
load_spec: LoadSpec | None = None,
discard_partial_chunks: bool = True,
save_decode_cache: bool = False,
) -> Optional["ReqMeta"]:
) -> "ReqMeta | None":
"""Create the request metadata from a request tracker.
Args:
......
......@@ -3,7 +3,7 @@
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
import torch
import zmq
......@@ -385,7 +385,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role, kv_cache_config)
......@@ -595,7 +595,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self.worker_adapter.shutdown()
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
"""
Get the KV connector stats collected during the last interval.
"""
......@@ -810,7 +810,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
) -> "KVConnectorStats | None":
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
......@@ -825,7 +825,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]:
) -> "KVConnectorPromMetrics | None":
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
......
......@@ -6,7 +6,7 @@ import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import msgspec
import numpy as np
......@@ -115,7 +115,7 @@ class MooncakeConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role, kv_cache_config)
......
......@@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import msgspec
import torch
......@@ -101,7 +101,7 @@ class MoRIIOAgentMetadata(
class RoleManager:
"""Manages role state across the connector."""
_instance: Optional["RoleManager"] = None
_instance: "RoleManager | None" = None
_lock = threading.Lock()
def __init__(self) -> None:
......
......@@ -7,7 +7,7 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import msgpack
import msgspec
......@@ -90,7 +90,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None, (
......@@ -333,7 +333,7 @@ class MoRIIOConnectorScheduler:
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
connector_worker: Optional["MoRIIOConnectorWorker"] = None,
connector_worker: "MoRIIOConnectorWorker | None" = None,
):
params = request.kv_transfer_params
if not params:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from weakref import ref as weakref_ref
import msgpack
......@@ -340,7 +340,7 @@ class MoRIIOWrapper:
def __init__(
self,
moriio_engine: Optional["IOEngine"] = None,
moriio_engine: "IOEngine | None" = None,
tp_rank: int = 0,
dp_rank: int = 0,
):
......
......@@ -14,7 +14,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import msgspec
import numpy as np
......@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role, kv_cache_config)
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import regex as re
import torch
......@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(
vllm_config=vllm_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
......@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
) -> None:
"""
Initialize KV cache transfer parallel group.
......
......@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Optional
from typing import Any
from unittest.mock import patch
import torch
......@@ -106,7 +106,7 @@ def _get_unique_name(name: str) -> str:
return newname
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
_groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}
def _register_group(group: "GroupCoordinator") -> None:
......@@ -784,7 +784,7 @@ class GroupCoordinator:
self,
tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
"""Send the input tensor dictionary.
......@@ -871,7 +871,7 @@ class GroupCoordinator:
def recv_tensor_dict(
self,
src: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
"""Recv the input tensor dictionary.
......
......@@ -3,7 +3,7 @@
"""Pydantic models for Anthropic API protocol"""
import time
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, field_validator
......@@ -135,7 +135,7 @@ class AnthropicStreamEvent(BaseModel):
"ping",
"error",
]
message: Optional["AnthropicMessagesResponse"] = None
message: "AnthropicMessagesResponse | None" = None
delta: AnthropicDelta | None = None
content_block: AnthropicContentBlock | None = None
index: int | None = None
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence
from typing import Optional
import torch
import torch.types
......@@ -126,7 +125,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod
def pack(
cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
cls, loras: GenericSequence["LoRALayerWeights | None"]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
......@@ -155,7 +154,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod
def pack_moe(
cls,
loras: GenericSequence[Optional["LoRALayerWeights"]],
loras: GenericSequence["LoRALayerWeights | None"],
module_name: str,
is_non_gated_moe: bool = False,
) -> "PackedLoRALayerWeights":
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import huggingface_hub
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
......@@ -131,7 +131,7 @@ def replace_submodule(
def parse_fine_tuned_lora_name(
name: str, weights_mapper: Optional["WeightsMapper"] = None
name: str, weights_mapper: "WeightsMapper | None" = None
) -> tuple[str, bool]:
"""Parse the name of lora weights.
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union
from typing import Union
import torch
......@@ -284,7 +284,7 @@ class FusedMoEQuantConfig:
return self._w1.bias
@property
def w1_precision(self) -> Optional["PrecisionConfig"]:
def w1_precision(self) -> "PrecisionConfig | None":
assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig)
return self._w1.scale
......@@ -306,7 +306,7 @@ class FusedMoEQuantConfig:
return self._w2.bias
@property
def w2_precision(self) -> Optional["PrecisionConfig"]:
def w2_precision(self) -> "PrecisionConfig | None":
assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig)
return self._w2.scale
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
......@@ -148,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional["QuantizationMethods"]:
) -> "QuantizationMethods | None":
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
......@@ -173,7 +173,7 @@ class AWQMarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment