Commit 6223dd81 (unverified), authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
......@@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme):
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None,
model_compression_config: Optional[dict[str, Any]] = None,
):
self.quantized = quantized
self.weight_quant = weight_quant
......@@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
......@@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme):
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: List[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = []
split_weights: list[torch.Tensor] = []
split_bitmask: list[torch.Tensor] = []
split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn import Parameter
......@@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
return 80
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
......@@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
......@@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
......@@ -19,7 +19,7 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
......@@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
......@@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
......@@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
return 80
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Type
from typing import Optional
import torch
......@@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,
......
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from typing import Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
......@@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
......@@ -84,7 +85,7 @@ def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
......@@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str,
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers. Returns first value in fused_mapping which matches targets
......@@ -201,7 +202,7 @@ def _match_fused_layer(
]
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
unfused_matches: list[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn as nn
......@@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return "deepspeedfp"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits=weight_bits, group_size=group_size)
......@@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return DeepSpeedFPLinearMethod(self)
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return 60
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
......@@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
def create_weights(self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
......@@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig):
return "experts_int8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
......@@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn import Module
......@@ -28,7 +28,7 @@ logger = init_logger(__name__)
class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float):
def __init__(self, ignore_list: list[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub
......@@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig):
return "fbgemm_fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
......@@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
......@@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
......@@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
ignored_layers: Optional[list[str]] = None,
weight_block_size: Optional[list[int]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
......@@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig):
return "fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
......@@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
......@@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import gguf
import torch
......@@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
......@@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
......@@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
self.params_dtype = params_dtype
......@@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: List[torch.Tensor]
data_container: list[torch.Tensor]
def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container}
......
......@@ -3,7 +3,7 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import torch
from torch.nn.parameter import Parameter
......@@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
dynamic: dict[str, dict[str, Union[int, bool]]],
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
......@@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig):
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
......@@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return "gptq_bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
......@@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
......@@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
"""
kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
......@@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Optional, Union
import torch
......@@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
......@@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig):
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
......@@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
......@@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig):
GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
......@@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
......@@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig):
return "gptq_marlin_24"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
......@@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
......@@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig):
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[List[str]] = None,
skip_modules: Optional[list[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
......@@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig):
return "hqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
......@@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
......@@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig):
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
......@@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig):
return "ipex"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
......@@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig):
return -1
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment