Unverified commit 6223dd81, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Optional
import torch import torch
from compressed_tensors import CompressionFormat, ModelCompressor from compressed_tensors import CompressionFormat, ModelCompressor
...@@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme):
quantized: bool = False, quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None, weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None, input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None, model_compression_config: Optional[dict[str, Any]] = None,
): ):
self.quantized = quantized self.quantized = quantized
self.weight_quant = weight_quant self.weight_quant = weight_quant
...@@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size: int, input_size: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
weight_loader: Callable, weight_loader: Callable,
...@@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme):
) )
return sparsity_compressor.decompress_weight(weight_data) return sparsity_compressor.decompress_weight(weight_data)
split_weights: List[torch.Tensor] = [] split_weights: list[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = [] split_bitmask: list[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = [] split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths) split_weights = torch.split(compressed, layer.logical_widths)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional from typing import Callable, Optional
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
...@@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
layer.meta = Parameter(layer.meta.data, requires_grad=False) layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional from typing import Callable, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): ...@@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
return 80 return 80
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional from typing import Callable, Optional
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
...@@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
prepare_fp8_layer_for_marlin(layer) prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional from typing import Callable, Optional
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
...@@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = None layer.input_scale = None
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set from typing import Callable, Optional
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
...@@ -19,7 +19,7 @@ logger = init_logger(__name__) ...@@ -19,7 +19,7 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme): class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set() _kernel_backends_being_used: set[str] = set()
def __init__(self, strategy: str, is_static_input_scheme: bool, def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool): input_symmetric: bool):
...@@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
return 75 return 75
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set from typing import Callable, Optional
import torch import torch
from compressed_tensors.quantization import ActivationOrdering from compressed_tensors.quantization import ActivationOrdering
...@@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) ...@@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme): class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set() _kernel_backends_being_used: set[str] = set()
def __init__(self, def __init__(self,
strategy: str, strategy: str,
...@@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
return 80 return 80
def create_weights(self, layer: torch.nn.Module, output_size: int, def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Type from typing import Optional
import torch import torch
...@@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor, ...@@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
block_size_m: int = 32, block_size_m: int = 32,
block_size_n: int = 32, block_size_n: int = 32,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional from typing import Optional
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from torch.nn import Module from torch.nn import Module
...@@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool: ...@@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool:
def should_ignore_layer( def should_ignore_layer(
layer_name: Optional[str], layer_name: Optional[str],
ignore: Iterable[str] = tuple(), ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool: ) -> bool:
if layer_name is None: if layer_name is None:
return False return False
...@@ -84,7 +85,7 @@ def find_matched_target( ...@@ -84,7 +85,7 @@ def find_matched_target(
layer_name: Optional[str], layer_name: Optional[str],
module: Module, module: Module,
targets: Iterable[str], targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> str: ) -> str:
""" """
Helper function to look up which "target" in the compressed-tensors Helper function to look up which "target" in the compressed-tensors
...@@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str, ...@@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str,
def _match_fused_layer( def _match_fused_layer(
layer_name: str, target_layers: Iterable[str], layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]: fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
""" """
Match a fused layer name to its corresponding individual layer in Match a fused layer name to its corresponding individual layer in
target_layers. Returns first value in fused_mapping which matches targets target_layers. Returns first value in fused_mapping which matches targets
...@@ -201,7 +202,7 @@ def _match_fused_layer( ...@@ -201,7 +202,7 @@ def _match_fused_layer(
] ]
# for each unfused component, find a match in targets # for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = [] unfused_matches: list[Optional[str]] = []
for unfused in unfused_paths: for unfused in unfused_paths:
for target in target_layers: for target in target_layers:
if _is_equal_or_regex_match(unfused, target): if _is_equal_or_regex_match(unfused, target):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return "deepspeedfp" return "deepspeedfp"
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits=weight_bits, group_size=group_size) return cls(weight_bits=weight_bits, group_size=group_size)
...@@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return DeepSpeedFPLinearMethod(self) return DeepSpeedFPLinearMethod(self)
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return 60 return 60
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
return [ return [
"quant_config.json", "quant_config.json",
"quantize_config.json", "quantize_config.json",
...@@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase): ...@@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
def create_weights(self, def create_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import torch import torch
...@@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig):
return "experts_int8" return "experts_int8"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half] return [torch.bfloat16, torch.half]
@classmethod @classmethod
...@@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config": def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
return cls() return cls()
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -28,7 +28,7 @@ logger = init_logger(__name__) ...@@ -28,7 +28,7 @@ logger = init_logger(__name__)
class FBGEMMFp8Config(QuantizationConfig): class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8.""" """Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float): def __init__(self, ignore_list: list[str], input_scale_ub: float):
super().__init__() super().__init__()
self.ignore_list = ignore_list if ignore_list else [] self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub self.input_scale_ub = input_scale_ub
...@@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig):
return "fbgemm_fp8" return "fbgemm_fp8"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16] return [torch.bfloat16, torch.float16]
@classmethod @classmethod
...@@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
...@@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib.util import importlib.util
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig): ...@@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig):
self, self,
is_checkpoint_fp8_serialized: bool = False, is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic", activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None, ignored_layers: Optional[list[str]] = None,
weight_block_size: Optional[List[int]] = None, weight_block_size: Optional[list[int]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
...@@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig): ...@@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig):
return "fp8" return "fp8"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half] return [torch.bfloat16, torch.half]
@classmethod @classmethod
...@@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig): ...@@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"]) quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method) is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
...@@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import gguf import gguf
import torch import torch
...@@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods: def get_name(self) -> QuantizationMethods:
return "gguf" return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32] return [torch.half, torch.bfloat16, torch.float32]
@classmethod @classmethod
...@@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig): ...@@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig):
return 60 return 60
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] # no extra configs. return [] # no extra configs.
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
return cls() return cls()
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
...@@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase): ...@@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): ...@@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
class GGUFUninitializedParameter(UninitializedParameter): class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter cls_to_become = Parameter
data_container: List[torch.Tensor] data_container: list[torch.Tensor]
def materialize_nested(self) -> Parameter: def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container} dtype = {data.dtype for data in self.data_container}
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import enum import enum
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig): ...@@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig):
group_size: int, group_size: int,
desc_act: bool, desc_act: bool,
lm_head_quantized: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]], dynamic: dict[str, dict[str, Union[int, bool]]],
) -> None: ) -> None:
# GPTQModel use `dynamic` config property to allow per module # GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized. # quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can # Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed) # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module. # matching of a module.
# Default to positive match, override base quant config mode, if no # Default to positive match, override base quant config mode, if no
...@@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig): ...@@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig):
return "gptq" return "gptq"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig): ...@@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig):
return 60 return 60
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic dynamic = {} if dynamic is None else dynamic
...@@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set from typing import Any, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig): ...@@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return "gptq_bitblas" return "gptq_bitblas"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig): ...@@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"]) desc_act = cls.get_from_keys(config, ["desc_act"])
...@@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig): ...@@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return self.TORCH_BITBLAS_STORAGE_DTYPE return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod @classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]): def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
num_bits = quant_config.get("bits") num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size") group_size = quant_config.get("group_size")
...@@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): ...@@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
""" """
kernel_type = BitBLASLinearKernel kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set() _kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None: def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
...@@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): ...@@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Set, Union from typing import Any, Callable, Optional, Union
import torch import torch
...@@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, desc_act: bool, def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool, is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]], dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None: full_config: dict[str, Any]) -> None:
super().__init__() super().__init__()
if desc_act and group_size == -1: if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False # In this case, act_order == True is the same as act_order == False
...@@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig):
# GPTQModel use `dynamic` config property to allow per module # GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized. # quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can # Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed) # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module. # matching of a module.
# Default to positive match, override base quant config mode, if no # Default to positive match, override base quant config mode, if no
...@@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return "gptq_marlin" return "gptq_marlin"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic dynamic = {} if dynamic is None else dynamic
...@@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig):
GPTQMarlinLinearMethod) GPTQMarlinLinearMethod)
@classmethod @classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower() quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits") num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size") group_size = quant_config.get("group_size")
...@@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config. quant_config: The GPTQ Marlin quantization config.
""" """
_kernel_backends_being_used: Set[str] = set() _kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
...@@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig):
return "gptq_marlin_24" return "gptq_marlin_24"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size) return cls(weight_bits, group_size)
...@@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): ...@@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
...@@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig): ...@@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig):
self, self,
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
skip_modules: Optional[List[str]] = None, skip_modules: Optional[list[str]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
assert group_size == 64, ("The only supported HQQ group size is " assert group_size == 64, ("The only supported HQQ group size is "
...@@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig): ...@@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig):
return "hqq" return "hqq"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig): ...@@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"]) wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"]) weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"]) group_size = cls.get_from_keys(wq_params, ["group_size"])
...@@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase): ...@@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
...@@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig): ...@@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig):
method: str, method: str,
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
modules_to_not_convert: Optional[List[str]] = None, modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None, desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None, lm_head_quantized: Optional[bool] = None,
) -> None: ) -> None:
...@@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig): ...@@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig):
return "ipex" return "ipex"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16] return [torch.bfloat16, torch.float16]
@classmethod @classmethod
...@@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig): ...@@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig):
return -1 return -1
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
return [ return [
"quant_config.json", "quant_config.json",
"quantize_config.json", "quantize_config.json",
] ]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower() method = cls.get_from_keys(config, ["quant_method"]).lower()
if method == "awq": if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment