Commit 1850a632 authored by Kai Zhang, committed by Facebook GitHub Bot

Sync quantization callback changes

Reviewed By: newstzpz

Differential Revision: D27255960

fbshipit-source-id: 1699ff23d2bc610dffc0215a90a7c1c17e3783c3
parent 1027896a
@@ -40,6 +40,16 @@ def rgetattr(obj: Any, attr: str, *args) -> Any:
     return functools.reduce(_getattr, [obj] + attr.split("."))
 
 
+def rhasattr(obj: Any, attr: str, *args) -> bool:
+    """ Same as hasattr but supports deeply nested objects. """
+    try:
+        _ = rgetattr(obj, attr, *args)
+    except AttributeError:
+        return False
+    return True
+
+
 def _deepcopy(pl_module: LightningModule) -> LightningModule:
     """Copy a LightningModule. Some properties need to be ignored. """
     # Remove _result before call to deepcopy since it stores non-leaf Tensors.
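For orientation, here is a minimal sketch of how the dotted-path helpers compose (the Outer/Inner classes are hypothetical, not part of this diff; rsetattr already exists in this module alongside rgetattr):

    # Hypothetical illustration of rgetattr/rsetattr and the new rhasattr.
    class Inner:
        def __init__(self):
            self.scale = 1.0

    class Outer:
        def __init__(self):
            self.inner = Inner()

    obj = Outer()
    assert rhasattr(obj, "inner.scale")        # resolves "inner", then "scale"
    rsetattr(obj, "inner.scale", 2.0)          # setattr on the resolved parent
    assert rgetattr(obj, "inner.scale") == 2.0
    assert not rhasattr(obj, "inner.missing")  # AttributeError is caught -> False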
@@ -101,11 +111,11 @@ class QuantizationMixin(ABC):
     As such, we could do something like the below, shown here for QAT.
 
     >>> class MyQuantizationCallback(QuantizationAwareTraining):
-    ...     def prepare(self, model, config):
+    ...     def prepare(self, model, config, attrs):
     ...         model.traceable = prepare_qat_fx(model.traceable, config)
     ...         return model
     ...
-    ...     def convert(self, model):
+    ...     def convert(self, model, attrs):
     ...         model.traceable = convert_fx(model.traceable)
     ...         return model
@@ -122,7 +132,9 @@ class QuantizationMixin(ABC):
     """
 
-    def prepare(self, root: LightningModule, configs: QConfigDicts) -> torch.nn.Module:
+    def prepare(
+        self, root: LightningModule, configs: QConfigDicts, attrs: Set[str]
+    ) -> torch.nn.Module:
         """Prepares the root user modules for quantization.
 
         By default, this tries to prepare the entire LightningModule. If this is
@@ -135,6 +147,7 @@ class QuantizationMixin(ABC):
             root: The LightningModule as given to the lightning Trainer in train mode.
             configs: Specification to be used when preparing the model, as provided by the user.
                 It is guaranteed that no key is a suffix of another.
+            attrs: The set of attributes to preserve on the module.
 
         Returns:
             The prepared Module to be used for quantization-aware training.
@@ -144,15 +157,23 @@ class QuantizationMixin(ABC):
             if isinstance(self, QuantizationAwareTraining)
             else prepare_fx
         )
+        old_attrs = {
+            attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
+        }
+        prepared = root
         if "" in configs:
-            return prep_fn(root, configs[""])
-        for name, config in configs.items():
-            submodule = rgetattr(root, name)
-            rsetattr(root, name, prep_fn(submodule, config))
-        return root
+            prepared = prep_fn(root, configs[""])
+        else:
+            for name, config in configs.items():
+                submodule = rgetattr(root, name)
+                rsetattr(root, name, prep_fn(submodule, config))
+        for attr, value in old_attrs.items():
+            rsetattr(prepared, attr, value)
+        return prepared
 
-    def convert(self, root: torch.nn.Module, submodules: Set[str]) -> torch.nn.Module:
+    def convert(
+        self, root: torch.nn.Module, submodules: Set[str], attrs: Set[str]
+    ) -> torch.nn.Module:
"""Quantizes a previously prepared module (as returned by `prepare`). """Quantizes a previously prepared module (as returned by `prepare`).
By default, this simply quantizes the entire root. If the `prepare` By default, this simply quantizes the entire root. If the `prepare`
...@@ -162,17 +183,24 @@ class QuantizationMixin(ABC): ...@@ -162,17 +183,24 @@ class QuantizationMixin(ABC):
root: The prepared model as returned by `prepare`, after training. root: The prepared model as returned by `prepare`, after training.
submodules: An iterator of fully qualified submodules names that require submodules: An iterator of fully qualified submodules names that require
converting. converting.
attrs: The list of attributes to maintain for the module across this call.
Returns: Returns:
The quantized model. The quantized model.
""" """
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
converted = root
if "" in submodules: if "" in submodules:
return convert_fx(root) converted = convert_fx(root)
else:
for name in submodules: for name in submodules:
prepared = rgetattr(root, name) prepared = rgetattr(root, name)
rsetattr(root, name, convert_fx(prepared)) rsetattr(root, name, convert_fx(prepared))
return root for attr, value in old_attrs.items():
rsetattr(converted, attr, value)
return converted
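To spell out the pattern that `prepare` and `convert` now share: FX graph-mode quantization returns a new GraphModule, which drops ad-hoc attributes set on the original module, so the attributes are captured before the rewrite and re-attached afterwards. A condensed sketch, where `transform` is a hypothetical stand-in for the prepare or convert step:

    # Sketch of the capture-and-restore idiom; `transform` is hypothetical.
    def _with_preserved_attrs(root, attrs, transform):
        old_attrs = {a: rgetattr(root, a) for a in attrs if rhasattr(root, a)}
        result = transform(root)           # may return a brand-new GraphModule
        for attr, value in old_attrs.items():
            rsetattr(result, attr, value)  # re-attach the captured attributes
        return result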
 @dataclass(frozen=True)
@@ -195,7 +223,7 @@ class ModelTransform:
     step: Optional[int] = None
     interval: Optional[int] = None
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """ Validate a few properties for early failure. """
         if (self.step is None and self.interval is None) or (
             self.step is not None and self.interval is not None
@@ -272,6 +300,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         qconfig_dicts: Optional[
             Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
         ] = None,
+        preserved_attrs: Optional[List[str]] = None,
     ) -> None:
""" """
Args: Args:
...@@ -285,6 +314,8 @@ class QuantizationAwareTraining(Callback, QuantizationMixin): ...@@ -285,6 +314,8 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
freeze_bn_step: If specified, the step at which we apply freeze the freeze_bn_step: If specified, the step at which we apply freeze the
collection of batch normalization layer statistics for QAT. collection of batch normalization layer statistics for QAT.
qconfig_dicts: If given, used for quantization of the model during training. qconfig_dicts: If given, used for quantization of the model during training.
preserved_attrs: If provided, a list of attributes to preserve across
quantized modules. These are preserved only if they already exists.
""" """
if start_step < 0: if start_step < 0:
raise ValueError( raise ValueError(
@@ -355,6 +386,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         )
 
         self.prepared: Optional[torch.nn.Module] = None
+        self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
         if not qconfig_dicts:
             self.qconfig_dicts: QConfigDicts = {"": {"": get_default_qat_qconfig()}}
         else:
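As a hedged usage sketch of the new argument (model and attribute names are illustrative, mirroring the test added further down):

    # Illustrative only: keep custom attributes through prepare/convert.
    qat = QuantizationAwareTraining(
        preserved_attrs=["_added_property", "layer._added_property"]
    )
    trainer = Trainer(callbacks=[qat], max_epochs=2)
    trainer.fit(model)  # qat.prepared and qat.quantized keep the listed attributes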
@@ -379,7 +411,9 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         with mode(pl_module, training=True) as train:
             pl_module._prepared = self.prepare(
-                _deepcopy(train), configs=self.qconfig_dicts
+                _deepcopy(train),
+                configs=self.qconfig_dicts,
+                attrs=self.preserved_attrs,
             )
         pl_module.forward = MethodType(_quantized_forward, pl_module)
         self.prepared = pl_module._prepared
@@ -419,7 +453,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         if hasattr(pl_module, "_quantized"):
             return
         pl_module._quantized = self.convert(
-            pl_module._prepared, self.qconfig_dicts.keys()
+            pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
         )
         self.quantized = pl_module._quantized
@@ -431,7 +465,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         if hasattr(pl_module, "_quantized"):
             return
         pl_module._quantized = self.convert(
-            pl_module._prepared, self.qconfig_dicts.keys()
+            pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
         )
         self.quantized = pl_module._quantized
@@ -481,9 +515,14 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
         the validation data. Only available after validation has ended.
     """
 
-    def __init__(self, qconfig_dicts: Optional[QConfigDicts] = None) -> None:
+    def __init__(
+        self,
+        qconfig_dicts: Optional[QConfigDicts] = None,
+        preserved_attrs: Optional[List[str]] = None,
+    ) -> None:
         """ Initialize the callback. """
         self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}}
+        self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
         self.prepared: Optional[torch.nn.Module] = None
         self.quantized: Optional[torch.nn.Module] = None
         self.should_calibrate = _requires_calibration(self.qconfig_dicts)
@@ -495,12 +534,16 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
         """
         # Pass a copy to quantization APIs.
         self.prepared = self.prepare(
-            _deepcopy(pl_module).eval(), configs=self.qconfig_dicts
+            _deepcopy(pl_module).eval(),
+            configs=self.qconfig_dicts,
+            attrs=self.preserved_attrs,
         )
 
     def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """ Convert the calibrated model to its finalized quantized version. """
-        self.quantized = self.convert(self.prepared, self.qconfig_dicts.keys())
+        self.quantized = self.convert(
+            self.prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
+        )
 
     def on_validation_batch_end(
         self,
...
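The same knob on the post-training path, sketched under the same illustrative names (calibration happens while validation batches flow through the prepared model):

    # Illustrative only: prepare on validation start, convert on validation end.
    ptq = PostTrainingQuantization(preserved_attrs=["_added_property"])
    trainer = Trainer(callbacks=[ptq], logger=False)
    trainer.fit(model)           # the validation loop calibrates the observers
    quantized = ptq.quantized    # preserved attributes re-attached by convert()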
@@ -10,8 +10,10 @@ from d2go.runner.callbacks.quantization import (
     PostTrainingQuantization,
     QuantizationAwareTraining,
     ModelTransform,
-    get_default_qconfig,
     get_default_qat_qconfig,
+    rgetattr,
+    rsetattr,
+    rhasattr,
 )
 from d2go.utils.misc import mode
 from d2go.utils.testing.helper import tempdir
@@ -27,6 +29,30 @@ from torch.quantization import (  # @manual
 )
+class TestUtilities(unittest.TestCase):
+    """ Test some basic utilities we rely on. """
+
+    def test_get_set_has(self):
+        """ Trivial test for generic behavior. Only supports pre-existing, deeply nested values. """
+
+        class TestObject(object):
+            def __init__(self):
+                self.object = None
+                self.set_to_five = 5
+
+        obj = TestObject()
+        obj.object = TestObject()
+        obj.object.set_to_five = 10
+        rsetattr(obj, "object.set_to_five", 1)
+        self.assertTrue(rhasattr(obj, "object.set_to_five"))
+        self.assertEqual(1, rgetattr(obj, "object.set_to_five"))
+        self.assertEqual(5, rgetattr(obj, "set_to_five"))
+
+        with self.assertRaises(AttributeError):
+            rsetattr(obj, "object.does_not_exist.five", 5)
+
+
 class TestModelTransform(unittest.TestCase):
     """ Tests ModelTransforms. """
@@ -212,6 +238,40 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         self.assertIsNotNone(qat.prepared)
         self.assertIsNotNone(qat.quantized)
+    @tempdir
+    def test_attribute_preservation_qat(self, root_dir):
+        """ Validates that we can preserve specified properties on the module. """
+        seed_everything(100)
+        model = TestModule()
+        model.layer._added_property = 10
+        model._not_preserved = 15
+        model._added_property = 20
+        num_epochs = 2
+        qat = QuantizationAwareTraining(
+            preserved_attrs=["_added_property", "layer._added_property"]
+        )
+        trainer = Trainer(
+            default_root_dir=os.path.join(root_dir, "quantized"),
+            checkpoint_callback=False,
+            callbacks=[qat],
+            max_epochs=num_epochs,
+            logger=False,
+        )
+        trainer.fit(model)
+
+        self.assertIsNotNone(qat.prepared)
+        self.assertIsNotNone(qat.quantized)
+        # Assert properties are maintained.
+        self.assertTrue(hasattr(qat.prepared, "_added_property"))
+        self.assertTrue(hasattr(qat.prepared.layer, "_added_property"))
+        with self.assertRaises(AttributeError):
+            qat.prepared._not_preserved
     @tempdir
     def test_quantization_and_checkpointing(self, root_dir):
         """ Validate written checkpoints can be loaded back as expected. """
@@ -243,11 +303,11 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         class _CustomQAT(QuantizationAwareTraining):
             """ Only quantize TestModule.another_layer. """
 
-            def prepare(self, model, configs):
+            def prepare(self, model, configs, attrs):
                 model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
                 return model
 
-            def convert(self, model, submodules):
+            def convert(self, model, submodules, attrs):
                 model.another_layer = convert_fx(model.another_layer)
                 return model
@@ -406,11 +466,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
         class _CustomStaticQuant(PostTrainingQuantization):
             """ Only quantize TestModule.another_layer. """
 
-            def prepare(self, model, configs):
+            def prepare(self, model, configs, attrs):
                 model.another_layer = prepare_fx(model.another_layer, configs[""])
                 return model
 
-            def convert(self, model, submodules):
+            def convert(self, model, submodules, attrs):
                 model.another_layer = convert_fx(model.another_layer)
                 return model
...