Commit 596f3721 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

add is_qat to lightning codepath

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/339

add is_qat to lightning codepath

Reviewed By: jerryzh168

Differential Revision: D37937336

fbshipit-source-id: 68debe57c7f7dcf8647fad6ab9e34eff2aaa851c
parent 6c8c2fc8
......@@ -99,7 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint
setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx())
setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx(is_qat=True))
class QuantizationMixin(ABC):
......@@ -161,13 +161,10 @@ class QuantizationMixin(ABC):
Returns:
The prepared Module to be used for quantized aware training.
"""
is_qat = isinstance(self, QuantizationAwareTraining)
if hasattr(root, "custom_prepare_fx"):
return root.custom_prepare_fx()
prep_fn = (
prepare_qat_fx
if isinstance(self, QuantizationAwareTraining)
else prepare_fx
)
return root.custom_prepare_fx(is_qat)
prep_fn = prepare_qat_fx if is_qat else prepare_fx
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
......
......@@ -477,12 +477,14 @@ class DefaultTask(pl.LightningModule):
# TODO: remove custom_prepare_fx/custom_convert_fx from LightningModule
def custom_prepare_fx(self) -> pl.LightningModule:
def custom_prepare_fx(self, is_qat) -> pl.LightningModule:
if hasattr(self.model, "custom_prepare_fx"):
self.model = self.model.custom_prepare_fx(self.cfg, example_input=None)
self.model = self.model.custom_prepare_fx(
self.cfg, is_qat, example_input=None
)
else:
self.model = default_custom_prepare_fx(
self.cfg, self.model, self.model.training, example_input=None
self.cfg, self.model, is_qat, example_input=None
)
return self
......
......@@ -52,7 +52,7 @@ class DetMetaArchForTest(torch.nn.Module):
ret = [{"instances": instance}]
return ret
def custom_prepare_fx(self, cfg, example_input=None):
def custom_prepare_fx(self, cfg, is_qat, example_input=None):
example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx(
self.avgpool,
......
......@@ -171,7 +171,7 @@ class TestLightningTask(unittest.TestCase):
self.avgpool.preserved_attr = "foo"
self.avgpool.not_preserved_attr = "bar"
def custom_prepare_fx(self, cfg, example_input=None):
def custom_prepare_fx(self, cfg, is_qat, example_input=None):
example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx(
self.avgpool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment