Unverified Commit 6223dd81 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
......@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import torch
......@@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
......@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
......@@ -75,7 +75,7 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type
from typing import Optional
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
......@@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
......@@ -46,7 +46,7 @@ def choose_mp_linear_kernel(
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
from typing import Optional
import torch
......@@ -21,10 +21,10 @@ logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
BITBLAS_DTYPES: dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
......@@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
is_bitblas_installed = True
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\
......
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
......
......@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
......@@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
......
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Dict, List, Optional, Type
from typing import Optional
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
......@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
......@@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]:
) -> type[ScaledMMLinearKernel]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
......@@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel(
ValueError: If no kernel can implement the given config.
Returns:
Type[ScaledMMLinearKernel]: Chosen kernel.
type[ScaledMMLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_rocm():
return (
False,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if (not current_platform.is_cuda() and not current_platform.is_cpu()):
return False, "CutlassScaledMM requires running on CUDA or CPU."
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,
......
# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
from functorch.experimental.control_flow import cond # noqa: F401
......@@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig):
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
......@@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
......@@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig):
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
......@@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig):
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
......@@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......@@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str,
exclude_modules: List[str],
exclude_modules: list[str],
group_size: int = 16,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
......@@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return "nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
......@@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
......@@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: List):
def is_layer_excluded(self, prefix: str, exclude_modules: list):
import re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
......@@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
......@@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig):
def __init__(self, linear_quant_method: str, weight_bits: int,
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
......@@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig):
return "moe_wna16"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
......@@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig):
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
......@@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig):
return None
@classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
......@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
return None
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
......
......@@ -2,7 +2,7 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from torch.nn import Module
......@@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]:
def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST
@classmethod
......@@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig):
"This function should not be called with Neuron Backend")
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config):
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
ignored_layers: Optional[list[str]] = None,
) -> None:
if not current_platform.is_rocm():
raise ValueError(
......@@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config):
return "ptpc_fp8"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(activation_scheme=activation_scheme,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig):
return "qqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
"""List of filenames to search for in the model directory."""
return [
"quant_config.json",
......@@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
......@@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
import fnmatch
from typing import Any, Dict, List, Optional, cast
from typing import Any, Optional, cast
import torch
......@@ -29,9 +29,9 @@ logger = init_logger(__name__)
class QuarkConfig(QuantizationConfig):
def __init__(self,
quant_config: Dict[str, Any],
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
quant_config: dict[str, Any],
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None:
......@@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig):
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
......@@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
......@@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
if export_config is None:
raise ValueError("The export key should be included in "
"the configurations of Quark quantized model")
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
......@@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(Dict[str, Any],
layer_quant_config = cast(dict[str, Any],
config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
......@@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig):
"configuration.")
q_configs = [
cast(Dict[str, Any], layer_quant_config.get(name))
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(
......@@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig):
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
q_proj_q_config = cast(dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
......@@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig):
pack_method=pack_method)
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self,
......@@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig):
else:
return False
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
......@@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig):
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
return is_per_tensor_activation
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
......@@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: "
......@@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig):
return True
def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:
module: torch.nn.Module) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
......@@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig):
return shard_configs[0]
else:
layer_quant_config = cast(
Dict[str, Any], self.quant_config.get("layer_quant_config"))
dict[str, Any], self.quant_config.get("layer_quant_config"))
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = cast(str, type(module))
layer_type_quant_config = cast(
Dict[str, Any],
dict[str, Any],
self.quant_config.get("layer_type_quant_config"))
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config"))
dict[str, Any], self.quant_config.get("global_quant_config"))
return global_quant_config
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported")
weight_config = cast(Dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors"))
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported(
......@@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
......@@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment