Unverified Commit 6223dd81 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional, Tuple from typing import Callable, Optional
import torch import torch
...@@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType ...@@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType
@dataclass @dataclass
class MPLinearLayerConfig: class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out] full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int] partition_weight_shape: tuple[int, int]
weight_type: ScalarType weight_type: ScalarType
act_type: torch.dtype act_type: torch.dtype
group_size: int group_size: int
...@@ -31,7 +31,7 @@ class MPLinearKernel(ABC): ...@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError raise NotImplementedError
def __init__(self, def __init__(self,
...@@ -75,7 +75,7 @@ class MPLinearKernel(ABC): ...@@ -75,7 +75,7 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False)) torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params( def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[ self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # w_q torch.Tensor, # w_q
torch.Tensor, # w_s torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp, Optional[torch.Tensor], # w_zp,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type from typing import Optional
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
from vllm.platforms import current_platform from vllm.platforms import current_platform
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
MacheteLinearKernel, MacheteLinearKernel,
AllSparkLinearKernel, AllSparkLinearKernel,
MarlinLinearKernel, MarlinLinearKernel,
...@@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ ...@@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
def choose_mp_linear_kernel( def choose_mp_linear_kernel(
config: MPLinearLayerConfig, config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
""" """
Choose an MPLinearKernel that can implement the given config for the given Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of compute capability. Attempts to choose the best kernel in terms of
...@@ -46,7 +46,7 @@ def choose_mp_linear_kernel( ...@@ -46,7 +46,7 @@ def choose_mp_linear_kernel(
ValueError: If no kernel can implement the given config. ValueError: If no kernel can implement the given config.
Returns: Returns:
Type[MPLinearKernel]: Chosen kernel. type[MPLinearKernel]: Chosen kernel.
""" """
if compute_capability is None: if compute_capability is None:
if current_platform is None: if current_platform is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel): ...@@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx: if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark" return False, "Act reordering currently not supported by AllSpark"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple from typing import Optional
import torch import torch
...@@ -21,10 +21,10 @@ logger = init_logger(__name__) ...@@ -21,10 +21,10 @@ logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel): class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt" MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = { BITBLAS_DTYPES: dict[torch.dtype, str] = {
torch.float32: "float32", torch.float32: "float32",
torch.float16: "float16", torch.float16: "float16",
torch.bfloat16: "bfloat16", torch.bfloat16: "bfloat16",
...@@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel): ...@@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
is_bitblas_installed = True is_bitblas_installed = True
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel): ...@@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\ if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]: c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\ return False, "Act reordering currently not supported by Exllama, "\
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import partial from functools import partial
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel): ...@@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\ if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]: c.partition_weight_shape[0] != c.full_weight_shape[0]:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel):
@classmethod @classmethod
def can_implement(cls, def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
quant_types = query_marlin_supported_quant_types(c.zero_points) quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types: if c.weight_type not in quant_types:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC): ...@@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def can_implement( def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError raise NotImplementedError
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
...@@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC): ...@@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError raise NotImplementedError
def _get_weight_params( def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[ self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # weight torch.Tensor, # weight
torch.Tensor, # weight_scale torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale, Optional[torch.Tensor], # input_scale,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from typing import Dict, List, Optional, Type from typing import Optional
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel) AiterScaledMMLinearKernel)
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
...@@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { ...@@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel( def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]: ) -> type[ScaledMMLinearKernel]:
""" """
Choose an ScaledMMLinearKernel that can implement the given config for the Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of given compute capability. Attempts to choose the best kernel in terms of
...@@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel( ...@@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel(
ValueError: If no kernel can implement the given config. ValueError: If no kernel can implement the given config.
Returns: Returns:
Type[ScaledMMLinearKernel]: Chosen kernel. type[ScaledMMLinearKernel]: Chosen kernel.
""" """
if compute_capability is None: if compute_capability is None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod @classmethod
def can_implement( def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_rocm(): if not current_platform.is_rocm():
return ( return (
False, False,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod @classmethod
def can_implement( def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if (not current_platform.is_cuda() and not current_platform.is_cpu()): if (not current_platform.is_cuda() and not current_platform.is_cpu()):
return False, "CutlassScaledMM requires running on CUDA or CPU." return False, "CutlassScaledMM requires running on CUDA or CPU."
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod @classmethod
def can_implement( def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if current_platform.is_cpu(): if current_platform.is_cpu():
return ( return (
False, False,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import warnings import warnings
from typing import Optional, Tuple from typing import Optional
import torch import torch
from functorch.experimental.control_flow import cond # noqa: F401 from functorch.experimental.control_flow import cond # noqa: F401
...@@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod @classmethod
def can_implement( def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_tpu(): if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU." return False, "ScaledMMXLA requires running on TPU."
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig): ...@@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig):
return "marlin" return "marlin"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig): ...@@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False) default=False)
...@@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig):
return "modelopt" return "modelopt"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half] return [torch.bfloat16, torch.half]
@classmethod @classmethod
...@@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig):
return 89 return 89
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"] return ["hf_quant_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"]) quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"] quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS: if quant_method not in QUANT_ALGOS:
...@@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
...@@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
self, self,
is_checkpoint_nvfp4_serialized: bool, is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str, kv_cache_quant_algo: str,
exclude_modules: List[str], exclude_modules: list[str],
group_size: int = 16, group_size: int = 16,
) -> None: ) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
...@@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return "nvfp4" return "nvfp4"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn] return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod @classmethod
...@@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"] return ["hf_quant_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config": def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"]) quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"] quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS: if quant_method not in QUANT_ALGOS:
...@@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig): ...@@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size) exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: List): def is_layer_excluded(self, prefix: str, exclude_modules: list):
import re import re
for pattern in exclude_modules: for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*') regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
...@@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import torch import torch
...@@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig):
def __init__(self, linear_quant_method: str, weight_bits: int, def __init__(self, linear_quant_method: str, weight_bits: int,
group_size: int, has_zp: bool, lm_head_quantized: bool, group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]], modules_to_not_convert: Optional[list[str]],
full_config: Dict[str, Any]) -> None: full_config: dict[str, Any]) -> None:
super().__init__() super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
...@@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig):
return "moe_wna16" return "moe_wna16"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half] return [torch.bfloat16, torch.half]
@classmethod @classmethod
...@@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig):
return 70 return 70
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
linear_quant_method = cls.get_from_keys(config, ["quant_method"]) linear_quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
...@@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig):
return None return None
@classmethod @classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower() quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits") num_bits = quant_config.get("bits")
...@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
return None return None
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert) return any(module_name in prefix for module_name in modules_to_not_convert)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import os import os
from importlib.util import find_spec from importlib.util import find_spec
from typing import Any, Dict, List, Optional from typing import Any, Optional
from torch.nn import Module from torch.nn import Module
...@@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig): ...@@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods: def get_name(self) -> QuantizationMethods:
return "neuron_quant" return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]: def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST return SUPPORTED_QUANT_DTYPE_LIST
@classmethod @classmethod
...@@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig): ...@@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig):
"This function should not be called with Neuron Backend") "This function should not be called with Neuron Backend")
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"]) quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype, return cls(dequant_dtype=dequant_dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config): ...@@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config):
def __init__( def __init__(
self, self,
activation_scheme: str = "dynamic", activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None, ignored_layers: Optional[list[str]] = None,
) -> None: ) -> None:
if not current_platform.is_rocm(): if not current_platform.is_rocm():
raise ValueError( raise ValueError(
...@@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config): ...@@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config):
return "ptpc_fp8" return "ptpc_fp8"
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(activation_scheme=activation_scheme, return cls(activation_scheme=activation_scheme,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig): ...@@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig):
return "qqq" return "qqq"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig): ...@@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
"""List of filenames to search for in the model directory.""" """List of filenames to search for in the model directory."""
return [ return [
"quant_config.json", "quant_config.json",
...@@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig): ...@@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig):
] ]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size) return cls(weight_bits, group_size)
...@@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase): ...@@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import fnmatch import fnmatch
from typing import Any, Dict, List, Optional, cast from typing import Any, Optional, cast
import torch import torch
...@@ -29,9 +29,9 @@ logger = init_logger(__name__) ...@@ -29,9 +29,9 @@ logger = init_logger(__name__)
class QuarkConfig(QuantizationConfig): class QuarkConfig(QuantizationConfig):
def __init__(self, def __init__(self,
quant_config: Dict[str, Any], quant_config: dict[str, Any],
kv_cache_group: Optional[List[str]] = None, kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None, kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder"): pack_method: str = "reorder"):
super().__init__() super().__init__()
if kv_cache_group is None: if kv_cache_group is None:
...@@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig):
def get_linear_method(self) -> "QuarkLinearMethod": def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self) return QuarkLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16] return [torch.float16, torch.bfloat16]
@classmethod @classmethod
...@@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude")) exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, if should_ignore_layer(prefix,
ignore=exclude_layers, ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping): fused_mapping=self.packed_modules_mapping):
...@@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig): ...@@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig):
return None return None
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export") export_config = config.get("export")
if export_config is None: if export_config is None:
raise ValueError("The export key should be included in " raise ValueError("The export key should be included in "
"the configurations of Quark quantized model") "the configurations of Quark quantized model")
kv_cache_group = cast(List[str], export_config.get("kv_cache_group")) kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method")) pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration # In the export model of quark, the quantization configuration
...@@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_config = None kv_cache_config = None
else: else:
kv_cache_set = set(kv_cache_group) kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(Dict[str, Any], layer_quant_config = cast(dict[str, Any],
config.get("layer_quant_config")) config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys()) layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names) layer_quant_set = set(layer_quant_names)
...@@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig):
"configuration.") "configuration.")
q_configs = [ q_configs = [
cast(Dict[str, Any], layer_quant_config.get(name)) cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group for name in kv_cache_group
] ]
if not all( if not all(
...@@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig):
# In case q_proj output is also quantized, remove the configuration # In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency. # to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any], q_proj_q_config = cast(dict[str, Any],
layer_quant_config.get("*q_proj")) layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None: if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None q_proj_q_config["output_tensors"] = None
...@@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig):
pack_method=pack_method) pack_method=pack_method)
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
def _check_scheme_supported(self, def _check_scheme_supported(self,
...@@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig): ...@@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig):
else: else:
return False return False
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool: input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized. # Confirm weights and input quantized.
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
return False return False
...@@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig): ...@@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig):
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
return is_per_tensor_activation return is_per_tensor_activation
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool: input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized. # Confirm weights and input quantized.
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
return False return False
...@@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig): ...@@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool: input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized. # Confirm weights and input quantized.
if weight_quant is None or input_quant is None: if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: " logger.debug("Quark model is not in MX-FP4 format: "
...@@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig):
return True return True
def _find_matched_config(self, layer_name: str, def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]: module: torch.nn.Module) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1] proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping: if proj_name in self.packed_modules_mapping:
...@@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig): ...@@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig):
return shard_configs[0] return shard_configs[0]
else: else:
layer_quant_config = cast( layer_quant_config = cast(
Dict[str, Any], self.quant_config.get("layer_quant_config")) dict[str, Any], self.quant_config.get("layer_quant_config"))
for name_pattern in layer_quant_config: for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern): if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern] return layer_quant_config[name_pattern]
layer_type = cast(str, type(module)) layer_type = cast(str, type(module))
layer_type_quant_config = cast( layer_type_quant_config = cast(
Dict[str, Any], dict[str, Any],
self.quant_config.get("layer_type_quant_config")) self.quant_config.get("layer_type_quant_config"))
if layer_type in layer_type_quant_config: if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type] return layer_type_quant_config[layer_type]
global_quant_config = cast( global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config")) dict[str, Any], self.quant_config.get("global_quant_config"))
return global_quant_config return global_quant_config
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"): if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError( raise NotImplementedError(
"Currently, Quark models with output_tensors " "Currently, Quark models with output_tensors "
"and bias quantized are not supported") "and bias quantized are not supported")
weight_config = cast(Dict[str, Any], config.get("weight")) weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors")) input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_fp8_w8a8(weight_config, input_config): if self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported( is_fp8_w8a8_supported = self._check_scheme_supported(
...@@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase): ...@@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
""" """
...@@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod): ...@@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config) super().__init__(quant_config)
@staticmethod @staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]): def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
""" """
Validator for the kv cache configuration. Useful for controlling the Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM kv cache quantization schemes, that are being supported in vLLM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment