Unverified Commit 11d8e3ce authored by a120092009's avatar a120092009 Committed by GitHub
Browse files

[Quantization] support pass MappingType for TorchAoConfig (#10927)



* [Quantization] support pass MappingType for TorchAoConfig

* Apply style fixes

---------
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 97fda1b7
......@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum):
TORCHAO = "torchao"
if is_torchao_available:
from torchao.quantization.quant_primitives import MappingType
class TorchAoJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, MappingType):
return obj.name
return super().default(obj)
@dataclass
class QuantizationConfigMixin:
"""
......@@ -673,4 +683,6 @@ class TorchAoConfig(QuantizationConfigMixin):
```
"""
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
return (
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
)
......@@ -76,6 +76,7 @@ if is_torch_available():
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
......@@ -122,6 +123,19 @@ class TorchAoConfigTest(unittest.TestCase):
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "int4dq",
"quant_type_kwargs": {
"act_mapping_type": "SYMMETRIC",
"group_size": 64
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment