Commit 1850a632 authored by Kai Zhang, committed by Facebook GitHub Bot

Sync quantization callback changes

Reviewed By: newstzpz

Differential Revision: D27255960

fbshipit-source-id: 1699ff23d2bc610dffc0215a90a7c1c17e3783c3
parent 1027896a
@@ -40,6 +40,16 @@ def rgetattr(obj: Any, attr: str, *args) -> Any:
return functools.reduce(_getattr, [obj] + attr.split("."))
def rhasattr(obj: Any, attr: str, *args) -> bool:
""" Same as hasattr but supports deeply nested objects. """
try:
_ = rgetattr(obj, attr, *args)
except AttributeError:
return False
return True
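# Illustrative sketch of how these helpers compose on a dotted attribute path;
# `rsetattr` (defined alongside these helpers) mirrors setattr for nested paths,
# and the object below is a stand-in used only for this example:
#
#   >>> from types import SimpleNamespace
#   >>> cfg = SimpleNamespace(model=SimpleNamespace(backbone="resnet"))
#   >>> rhasattr(cfg, "model.backbone")
#   True
#   >>> rsetattr(cfg, "model.backbone", "fpn")
#   >>> rgetattr(cfg, "model.backbone")
#   'fpn'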
def _deepcopy(pl_module: LightningModule) -> LightningModule:
"""Copy a LightningModule. Some properties need to be ignored. """
# Remove _result before the call to deepcopy since it stores non-leaf Tensors.
@@ -101,11 +111,11 @@ class QuantizationMixin(ABC):
As such, we could do something like the below, shown here for QAT.
>>> class MyQuantizationCallback(QuantizedAwareTraining):
... def prepare(self, model, config):
... def prepare(self, model, config, attrs):
... model.traceable = prepare_qat_fx(model.traceable, config)
... return model
...
... def convert(self, model):
... def convert(self, model, attrs):
... model.traceable = convert_fx(model.traceable)
... return model
@@ -122,7 +132,9 @@ class QuantizationMixin(ABC):
"""
def prepare(self, root: LightningModule, configs: QConfigDicts) -> torch.nn.Module:
def prepare(
self, root: LightningModule, configs: QConfigDicts, attrs: Set[str]
) -> torch.nn.Module:
"""Prepares the root user modules for quantization.
By default, this tries to prepare the entire LightningModule. If this is
@@ -135,6 +147,7 @@ class QuantizationMixin(ABC):
root: The LightningModule as given to the lightning Trainer in train mode.
configs: Specification to be used when preparing the model, as provided by the user.
It is guaranteed that no key is a suffix of another.
attrs: The set of attribute names to preserve on the module.
Returns:
The prepared Module to be used for quantized aware training.
@@ -144,15 +157,23 @@ class QuantizationMixin(ABC):
if isinstance(self, QuantizationAwareTraining)
else prepare_fx
)
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
prepared = root
if "" in configs:
return prep_fn(root, configs[""])
prepared = prep_fn(root, configs[""])
else:
for name, config in configs.items():
submodule = rgetattr(root, name)
rsetattr(root, name, prep_fn(submodule, config))
return root
for attr, value in old_attrs.items():
rsetattr(prepared, attr, value)
return prepared
def convert(self, root: torch.nn.Module, submodules: Set[str]) -> torch.nn.Module:
def convert(
self, root: torch.nn.Module, submodules: Set[str], attrs: Set[str]
) -> torch.nn.Module:
"""Quantizes a previously prepared module (as returned by `prepare`).
By default, this simply quantizes the entire root. If the `prepare`
@@ -162,17 +183,24 @@ class QuantizationMixin(ABC):
root: The prepared model as returned by `prepare`, after training.
submodules: An iterable of fully qualified submodule names that require
converting.
attrs: The set of attribute names to preserve on the module across this call.
Returns:
The quantized model.
"""
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
converted = root
if "" in submodules:
return convert_fx(root)
converted = convert_fx(root)
else:
for name in submodules:
prepared = rgetattr(root, name)
rsetattr(root, name, convert_fx(prepared))
return root
for attr, value in old_attrs.items():
rsetattr(converted, attr, value)
return converted
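# Sketch of the attribute round-trip implemented by prepare() and convert() above:
# each name in `attrs` that already exists on the incoming module is captured with
# rgetattr before the FX rewrite and written back onto the returned module with
# rsetattr. `pl_module`, `configs`, and `layer.scale` below are placeholders:
#
#   >>> pl_module.layer.scale = 0.5
#   >>> prepared = self.prepare(pl_module, configs, attrs={"layer.scale"})
#   >>> prepared.layer.scale
#   0.5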
@dataclass(frozen=True)
@@ -195,7 +223,7 @@ class ModelTransform:
step: Optional[int] = None
interval: Optional[int] = None
def __post_init__(self):
def __post_init__(self) -> None:
""" Validate a few properties for early failure. """
if (self.step is None and self.interval is None) or (
self.step is not None and self.interval is not None
@@ -272,6 +300,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
qconfig_dicts: Optional[
Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
] = None,
preserved_attrs: Optional[List[str]] = None,
) -> None:
"""
Args:
@@ -285,6 +314,8 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
freeze_bn_step: If specified, the step at which we freeze the
collection of batch normalization layer statistics for QAT.
qconfig_dicts: If given, used for quantization of the model during training.
preserved_attrs: If provided, a list of attributes to preserve across
quantized modules. These are preserved only if they already exist.
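Example (illustrative; `_metadata` and `layer._scale` stand in for attributes
set on the module before training):
>>> qat = QuantizationAwareTraining(
...     preserved_attrs=["_metadata", "layer._scale"]
... )
>>> Trainer(callbacks=[qat]).fit(my_module)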
"""
if start_step < 0:
raise ValueError(
@@ -355,6 +386,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
)
self.prepared: Optional[torch.nn.Module] = None
self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
if not qconfig_dicts:
self.qconfig_dicts: QConfigDicts = {"": {"": get_default_qat_qconfig()}}
else:
@@ -379,7 +411,9 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
with mode(pl_module, training=True) as train:
pl_module._prepared = self.prepare(
_deepcopy(train), configs=self.qconfig_dicts
_deepcopy(train),
configs=self.qconfig_dicts,
attrs=self.preserved_attrs,
)
pl_module.forward = MethodType(_quantized_forward, pl_module)
self.prepared = pl_module._prepared
@@ -419,7 +453,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
if hasattr(pl_module, "_quantized"):
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys()
pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
)
self.quantized = pl_module._quantized
@@ -431,7 +465,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
if hasattr(pl_module, "_quantized"):
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys()
pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
)
self.quantized = pl_module._quantized
@@ -481,9 +515,14 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
the validation data. Only available after validation has ended.
"""
def __init__(self, qconfig_dicts: Optional[QConfigDicts] = None) -> None:
def __init__(
self,
qconfig_dicts: Optional[QConfigDicts] = None,
preserved_attrs: Optional[List[str]] = None,
) -> None:
""" Initialize the callback. """
self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}}
self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
self.prepared: Optional[torch.nn.Module] = None
self.quantized: Optional[torch.nn.Module] = None
self.should_calibrate = _requires_calibration(self.qconfig_dicts)
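# Sketch of typical usage (names illustrative): attach the callback to a Lightning
# Trainer, run validation (which serves as calibration for static qconfigs), then
# read the converted model off the callback once validation has ended:
#
#   >>> ptq = PostTrainingQuantization(preserved_attrs=["layer._scale"])
#   >>> Trainer(callbacks=[ptq]).fit(my_module)
#   >>> quantized_model = ptq.quantized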
@@ -495,12 +534,16 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
"""
# Pass a copy to quantization APIs.
self.prepared = self.prepare(
_deepcopy(pl_module).eval(), configs=self.qconfig_dicts
_deepcopy(pl_module).eval(),
configs=self.qconfig_dicts,
attrs=self.preserved_attrs,
)
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Convert the calibrated model to its finalized quantized version. """
self.quantized = self.convert(self.prepared, self.qconfig_dicts.keys())
self.quantized = self.convert(
self.prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
)
def on_validation_batch_end(
self,
@@ -10,8 +10,10 @@ from d2go.runner.callbacks.quantization import (
PostTrainingQuantization,
QuantizationAwareTraining,
ModelTransform,
get_default_qconfig,
get_default_qat_qconfig,
rgetattr,
rsetattr,
rhasattr,
)
from d2go.utils.misc import mode
from d2go.utils.testing.helper import tempdir
@@ -27,6 +29,30 @@ from torch.quantization import ( # @manual; @manual
)
class TestUtilities(unittest.TestCase):
""" Test some basic utilities we rely on. """
def test_get_set_has(self):
""" Trivial test for generic behavior. Only support pre-existing deeply nested values."""
class TestObject(object):
def __init__(self):
self.object = None
self.set_to_five = 5
obj = TestObject()
obj.object = TestObject()
obj.object.set_to_five = 10
rsetattr(obj, "object.set_to_five", 1)
self.assertTrue(rhasattr(obj, "object.set_to_five"))
self.assertEqual(1, rgetattr(obj, "object.set_to_five"))
self.assertEqual(5, rgetattr(obj, "set_to_five"))
with self.assertRaises(AttributeError):
rsetattr(obj, "object.does_not_exist.five", 5)
class TestModelTransform(unittest.TestCase):
""" Tests ModelTransforms. """
@@ -212,6 +238,40 @@ class TestQuantizationAwareTraining(unittest.TestCase):
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
@tempdir
def test_attribute_preservation_qat(self, root_dir):
""" Validates we can preserve specified properties in module. """
seed_everything(100)
model = TestModule()
model.layer._added_property = 10
model._not_preserved = 15
model._added_property = 20
num_epochs = 2
qat = QuantizationAwareTraining(
preserved_attrs=["_added_property", "layer._added_property"]
)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
# Assert properties are maintained.
self.assertTrue(hasattr(qat.prepared, "_added_property"))
self.assertTrue(hasattr(qat.prepared.layer, "_added_property"))
with self.assertRaises(AttributeError):
qat.prepared._not_preserved
@tempdir
def test_quantization_and_checkpointing(self, root_dir):
""" Validate written checkpoints can be loaded back as expected. """
@@ -243,11 +303,11 @@ class TestQuantizationAwareTraining(unittest.TestCase):
class _CustomQAT(QuantizationAwareTraining):
""" Only quantize TestModule.another_layer. """
def prepare(self, model, configs):
def prepare(self, model, configs, attrs):
model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
return model
def convert(self, model, submodules):
def convert(self, model, submodules, attrs):
model.another_layer = convert_fx(model.another_layer)
return model
@@ -406,11 +466,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
class _CustomStaticQuant(PostTrainingQuantization):
""" Only quantize TestModule.another_layer. """
def prepare(self, model, configs):
def prepare(self, model, configs, attrs):
model.another_layer = prepare_fx(model.another_layer, configs[""])
return model
def convert(self, model, submodules):
def convert(self, model, submodules, attrs):
model.another_layer = convert_fx(model.another_layer)
return model