Unverified Commit 12913d17 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Quant] Add `SupportsQuant` to phi3 and clip (#13104)

parent 80f63a39
...@@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig): ...@@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig):
num_codebooks: int, num_codebooks: int,
out_group_size: int, out_group_size: int,
) -> None: ) -> None:
super().__init__()
self.in_group_size = in_group_size self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks self.num_codebooks = num_codebooks
......
...@@ -26,6 +26,7 @@ class AWQConfig(QuantizationConfig): ...@@ -26,6 +26,7 @@ class AWQConfig(QuantizationConfig):
zero_point: bool, zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None, modules_to_not_convert: Optional[List[str]] = None,
) -> None: ) -> None:
super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.zero_point = zero_point self.zero_point = zero_point
......
...@@ -47,6 +47,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -47,6 +47,7 @@ class AWQMarlinConfig(QuantizationConfig):
lm_head_quantized: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]], modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None: full_config: Dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32 self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
self.zero_point = zero_point self.zero_point = zero_point
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Type from typing import Any, Dict, List, Optional, Type
import torch import torch
from torch import nn from torch import nn
...@@ -59,7 +59,11 @@ def method_has_implemented_embedding( ...@@ -59,7 +59,11 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
"""Base class for quantization configs.""" """Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> str:
......
...@@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules: Optional[List[str]] = None, llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 6.0, llm_int8_threshold: float = 6.0,
) -> None: ) -> None:
super().__init__()
self.load_in_8bit = load_in_8bit self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
......
...@@ -51,7 +51,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -51,7 +51,7 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: Optional[Dict[str, Any]] = None, kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
): ):
super().__init__()
self.ignore = ignore self.ignore = ignore
self.quant_format = quant_format self.quant_format = quant_format
# Map from [target -> scheme] # Map from [target -> scheme]
......
...@@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
weight_bits: int = 8, weight_bits: int = 8,
group_size: int = 512, group_size: int = 512,
) -> None: ) -> None:
super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16] self.valid_types = [torch.bfloat16, torch.float16]
......
...@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization.""" """Config class for Int8 experts quantization."""
def __init__(self) -> None: def __init__(self) -> None:
pass super().__init__()
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
......
...@@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8.""" """Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float): def __init__(self, ignore_list: List[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else [] self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub self.input_scale_ub = input_scale_ub
......
...@@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig): ...@@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig):
ignored_layers: Optional[List[str]] = None, ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None, weight_block_size: Optional[List[int]] = None,
) -> None: ) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized: if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the " logger.warning("Detected fp8 checkpoint. Please note that the "
......
...@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
"""Config class for GGUF.""" """Config class for GGUF."""
def __init__(self, ) -> None: def __init__(self, ) -> None:
pass super().__init__()
def __repr__(self) -> str: def __repr__(self) -> str:
return ("GGUFConfig()") return ("GGUFConfig()")
......
...@@ -58,6 +58,7 @@ class GPTQConfig(QuantizationConfig): ...@@ -58,6 +58,7 @@ class GPTQConfig(QuantizationConfig):
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# } # }
super().__init__()
self.dynamic = dynamic self.dynamic = dynamic
self.weight_bits = weight_bits self.weight_bits = weight_bits
......
...@@ -46,6 +46,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -46,6 +46,7 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym: bool, lm_head_quantized: bool, is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]], dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None: full_config: Dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1: if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False # In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel) # (since we have only one group per output channel)
......
...@@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
) -> None: ) -> None:
super().__init__()
quant_type = { quant_type = {
4: scalar_types.uint4b8, 4: scalar_types.uint4b8,
8: scalar_types.uint8b128, 8: scalar_types.uint8b128,
......
...@@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig): ...@@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig):
group_size: int, group_size: int,
skip_modules: Optional[List[str]] = None, skip_modules: Optional[List[str]] = None,
) -> None: ) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is " assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.") "currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization " assert weight_bits == 4, ("The only supported HQQ quantization "
......
...@@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig): ...@@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig):
desc_act: Optional[bool] = None, desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None, lm_head_quantized: Optional[bool] = None,
) -> None: ) -> None:
super().__init__()
self.method = method self.method = method
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
......
...@@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig):
self, self,
is_checkpoint_fp8_serialized: bool = False, is_checkpoint_fp8_serialized: bool = False,
) -> None: ) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized: if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
......
...@@ -24,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -24,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig):
group_size: int, has_zp: bool, lm_head_quantized: bool, group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]], modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None: full_config: Dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.has_zp = has_zp self.has_zp = has_zp
......
...@@ -20,6 +20,7 @@ class NeuronQuantConfig(QuantizationConfig): ...@@ -20,6 +20,7 @@ class NeuronQuantConfig(QuantizationConfig):
dequant_dtype: str = "f16", dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic", quantize_method: str = "vector_dynamic",
) -> None: ) -> None:
super().__init__()
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError( raise ValueError(
......
...@@ -39,6 +39,7 @@ class QQQConfig(QuantizationConfig): ...@@ -39,6 +39,7 @@ class QQQConfig(QuantizationConfig):
group_size: int, group_size: int,
is_sym: bool = True, is_sym: bool = True,
) -> None: ) -> None:
super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.is_sym = is_sym self.is_sym = is_sym
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment