Unverified Commit c8ea982d authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `platform`, `plugins`, `triton_utils`, `vllm_flash_attn` (#18129)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent dc372b9c
...@@ -78,13 +78,8 @@ exclude = [ ...@@ -78,13 +78,8 @@ exclude = [
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"]
"vllm/utils.py" = ["UP006", "UP035"] "vllm/utils.py" = ["UP006", "UP035"]
......
...@@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context. ...@@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context.
import os import os
from functools import wraps from functools import wraps
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
Union)
import torch import torch
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
...@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform): ...@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
device_control_env_var: str = "CUDA_VISIBLE_DEVICES" device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
@property @property
def supported_dtypes(self) -> List[torch.dtype]: def supported_dtypes(self) -> list[torch.dtype]:
if self.has_device_capability(80): if self.has_device_capability(80):
# Ampere and Hopper or later NVIDIA GPUs. # Ampere and Hopper or later NVIDIA GPUs.
return [torch.bfloat16, torch.float16, torch.float32] return [torch.bfloat16, torch.float16, torch.float32]
...@@ -93,7 +92,7 @@ class CudaPlatformBase(Platform): ...@@ -93,7 +92,7 @@ class CudaPlatformBase(Platform):
return True return True
@classmethod @classmethod
def is_fully_connected(cls, device_ids: List[int]) -> bool: def is_fully_connected(cls, device_ids: list[int]) -> bool:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
...@@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase): ...@@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@with_nvml_context @with_nvml_context
def has_device_capability( def has_device_capability(
cls, cls,
capability: Union[Tuple[int, int], int], capability: Union[tuple[int, int], int],
device_id: int = 0, device_id: int = 0,
) -> bool: ) -> bool:
try: try:
...@@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase): ...@@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@classmethod @classmethod
@with_nvml_context @with_nvml_context
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
""" """
query if the set of gpus are fully connected by nvlink (1 hop) query if the set of gpus are fully connected by nvlink (1 hop)
""" """
...@@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase): ...@@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
return device_props.total_memory return device_props.total_memory
@classmethod @classmethod
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
logger.exception( logger.exception(
"NVLink detection not possible, as context support was" "NVLink detection not possible, as context support was"
" not found. Assuming no NVLink available.") " not found. Assuming no NVLink available.")
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import platform import platform
import random import random
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -200,7 +200,7 @@ class Platform: ...@@ -200,7 +200,7 @@ class Platform:
@classmethod @classmethod
def has_device_capability( def has_device_capability(
cls, cls,
capability: Union[Tuple[int, int], int], capability: Union[tuple[int, int], int],
device_id: int = 0, device_id: int = 0,
) -> bool: ) -> bool:
""" """
...@@ -362,7 +362,7 @@ class Platform: ...@@ -362,7 +362,7 @@ class Platform:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
""" """
Return the platform specific values for (-inf, inf) Return the platform specific values for (-inf, inf)
""" """
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import os import os
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -35,7 +35,7 @@ except ImportError as e: ...@@ -35,7 +35,7 @@ except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e) logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm. # Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = [] _ROCM_UNSUPPORTED_MODELS: list[str] = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
...@@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " ...@@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, " "Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting " "please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`") "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
"Qwen2ForCausalLM": "Qwen2ForCausalLM":
_ROCM_SWA_REASON, _ROCM_SWA_REASON,
"MistralForCausalLM": "MistralForCausalLM":
...@@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { ...@@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"excessive use of shared memory. If this happens, disable Triton FA " "excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
} }
_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = { _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
"0x74a0": "AMD_Instinct_MI300A", "0x74a0": "AMD_Instinct_MI300A",
"0x74a1": "AMD_Instinct_MI300X", "0x74a1": "AMD_Instinct_MI300X",
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF "0x74b5": "AMD_Instinct_MI300X", # MI300X VF
...@@ -203,7 +203,7 @@ class RocmPlatform(Platform): ...@@ -203,7 +203,7 @@ class RocmPlatform(Platform):
@staticmethod @staticmethod
@with_amdsmi_context @with_amdsmi_context
def is_fully_connected(physical_device_ids: List[int]) -> bool: def is_fully_connected(physical_device_ids: list[int]) -> bool:
""" """
Query if the set of gpus are fully connected by xgmi (1 hop) Query if the set of gpus are fully connected by xgmi (1 hop)
""" """
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
from tpu_info import device from tpu_info import device
...@@ -73,7 +73,7 @@ class TpuPlatform(Platform): ...@@ -73,7 +73,7 @@ class TpuPlatform(Platform):
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
@classmethod @classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
return torch.finfo(dtype).min, torch.finfo(dtype).max return torch.finfo(dtype).min, torch.finfo(dtype).max
@classmethod @classmethod
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import Callable, Dict from typing import Callable
import torch import torch
...@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) ...@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
plugins_loaded = False plugins_loaded = False
def load_plugins_by_group(group: str) -> Dict[str, Callable]: def load_plugins_by_group(group: str) -> dict[str, Callable]:
import sys import sys
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
from importlib_metadata import entry_points from importlib_metadata import entry_points
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment