Unverified commit 11d8e3ce, authored by a120092009 and committed by GitHub
Browse files

[Quantization] support pass MappingType for TorchAoConfig (#10927)



* [Quantization] support pass MappingType for TorchAoConfig

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 97fda1b7
...@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum): ...@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum):
TORCHAO = "torchao" TORCHAO = "torchao"
# Guard the torchao import so this module stays importable without torchao.
# NOTE: `is_torchao_available` must be CALLED — the bare function object is
# always truthy, which would make the import unconditional and raise
# ImportError whenever torchao is not installed. The called form also matches
# the `is_torchao_available()` usage elsewhere in the codebase.
if is_torchao_available():
    # MappingType is torchao's enum for symmetric/asymmetric quantization;
    # it is needed to JSON-serialize quant_type_kwargs in TorchAoConfig.__repr__.
    from torchao.quantization.quant_primitives import MappingType
class TorchAoJSONEncoder(json.JSONEncoder):
    """JSON encoder aware of torchao's ``MappingType`` enum.

    Members are serialized as their ``name`` string (e.g. ``"SYMMETRIC"``);
    every other type is delegated to the stock :class:`json.JSONEncoder`.
    """

    def default(self, obj):
        # Anything that is not a MappingType member falls through to the
        # base-class handler (which raises TypeError for unknown types).
        if not isinstance(obj, MappingType):
            return super().default(obj)
        return obj.name
@dataclass @dataclass
class QuantizationConfigMixin: class QuantizationConfigMixin:
""" """
...@@ -673,4 +683,6 @@ class TorchAoConfig(QuantizationConfigMixin): ...@@ -673,4 +683,6 @@ class TorchAoConfig(QuantizationConfigMixin):
``` ```
""" """
config_dict = self.to_dict() config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" return (
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
)
...@@ -76,6 +76,7 @@ if is_torch_available(): ...@@ -76,6 +76,7 @@ if is_torch_available():
if is_torchao_available(): if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes from torchao.utils import get_model_size_in_bytes
...@@ -122,6 +123,19 @@ class TorchAoConfigTest(unittest.TestCase): ...@@ -122,6 +123,19 @@ class TorchAoConfigTest(unittest.TestCase):
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr) self.assertEqual(quantization_repr, expected_repr)
# An enum-valued kwarg (act_mapping_type=MappingType.SYMMETRIC) must appear in
# repr() as its name string "SYMMETRIC" — this exercises the MappingType-aware
# JSON encoder used by TorchAoConfig.__repr__; whitespace is stripped from both
# sides so only content is compared.
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "int4dq",
"quant_type_kwargs": {
"act_mapping_type": "SYMMETRIC",
"group_size": 64
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment