Commit 596f3721 authored by Yanghan Wang, committed by Facebook GitHub Bot

add is_qat to lightning codepath

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/339

add is_qat to lightning codepath

Reviewed By: jerryzh168

Differential Revision: D37937336

fbshipit-source-id: 68debe57c7f7dcf8647fad6ab9e34eff2aaa851c
parent 6c8c2fc8
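
For context on what the new flag toggles: in FX graph mode quantization, QAT preparation inserts fake-quantize modules on a trainable model, while post-training quantization inserts observers for calibration. Below is a minimal standalone sketch of the two paths `is_qat` switches between, assuming a recent PyTorch; the toy model, backend string, and qconfig helpers are illustrative, not from this PR.

```python
# Minimal sketch, not part of this diff. Assumes torch>=1.13, where
# prepare_fx/prepare_qat_fx take (model, qconfig_mapping, example_inputs).
import torch
from torch.ao.quantization import (
    get_default_qat_qconfig_mapping,
    get_default_qconfig_mapping,
)
from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_inputs = (torch.rand(1, 3, 8, 8),)

def prepare(model: torch.nn.Module, is_qat: bool) -> torch.nn.Module:
    if is_qat:
        # QAT: insert fake-quant modules; the model stays trainable.
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
        return prepare_qat_fx(model.train(), qconfig_mapping, example_inputs)
    # PTQ: insert observers, to be calibrated with the model in eval mode.
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    return prepare_fx(model.eval(), qconfig_mapping, example_inputs)

prepared = prepare(model, is_qat=True)
```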
@@ -99,7 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
 def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
     if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
         # model has been prepared for QAT before saving into checkpoint
-        setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx())
+        setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx(is_qat=True))


 class QuantizationMixin(ABC):
@@ -161,13 +161,10 @@ class QuantizationMixin(ABC):
         Returns:
             The prepared Module to be used for quantized aware training.
         """
+        is_qat = isinstance(self, QuantizationAwareTraining)
         if hasattr(root, "custom_prepare_fx"):
-            return root.custom_prepare_fx()
+            return root.custom_prepare_fx(is_qat)
-        prep_fn = (
-            prepare_qat_fx
-            if isinstance(self, QuantizationAwareTraining)
-            else prepare_fx
-        )
+        prep_fn = prepare_qat_fx if is_qat else prepare_fx
         old_attrs = {
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
         }
...
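
The hunk above derives `is_qat` once from the callback's type instead of re-checking `isinstance` inline. A hedged standalone sketch of that dispatch (the callback classes are stubs standing in for d2go's `QuantizationAwareTraining` and `PostTrainingQuantization`; only the isinstance-based selection is the point):

```python
# Illustrative stubs only; the real callbacks live in d2go.
from typing import Callable
from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx

class QuantizationAwareTraining:  # stub for the d2go QAT callback
    pass

class PostTrainingQuantization:  # stub for the d2go PTQ callback
    pass

def select_prepare_fn(callback: object) -> Callable:
    # Deriving is_qat from the callback type is unambiguous, unlike
    # model.training, which the trainer may toggle around prepare time.
    is_qat = isinstance(callback, QuantizationAwareTraining)
    return prepare_qat_fx if is_qat else prepare_fx

assert select_prepare_fn(QuantizationAwareTraining()) is prepare_qat_fx
assert select_prepare_fn(PostTrainingQuantization()) is prepare_fx
```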
@@ -477,12 +477,14 @@ class DefaultTask(pl.LightningModule):
     # TODO: remove custom_prepare_fx/custom_convert_fx from LightningModule
-    def custom_prepare_fx(self) -> pl.LightningModule:
+    def custom_prepare_fx(self, is_qat) -> pl.LightningModule:
         if hasattr(self.model, "custom_prepare_fx"):
-            self.model = self.model.custom_prepare_fx(self.cfg, example_input=None)
+            self.model = self.model.custom_prepare_fx(
+                self.cfg, is_qat, example_input=None
+            )
         else:
             self.model = default_custom_prepare_fx(
-                self.cfg, self.model, self.model.training, example_input=None
+                self.cfg, self.model, is_qat, example_input=None
             )
         return self
...
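
Note the semantic change in this hunk: the fallback path previously passed `self.model.training` as the QAT flag, so whether a model was prepared for QAT or PTQ silently depended on its train/eval state; now the caller states it explicitly. A hypothetical call-site sketch (`TaskStub` is illustrative, not the real `DefaultTask`):

```python
# Hypothetical usage sketch; TaskStub stands in for a configured DefaultTask.
class TaskStub:
    def custom_prepare_fx(self, is_qat):
        # A real DefaultTask would prepare self.model here (see hunk above).
        print(f"preparing with is_qat={is_qat}")
        return self

task = TaskStub()
task = task.custom_prepare_fx(is_qat=True)    # QAT path
# task = task.custom_prepare_fx(is_qat=False) # PTQ path
```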
@@ -52,7 +52,7 @@ class DetMetaArchForTest(torch.nn.Module):
         ret = [{"instances": instance}]
         return ret

-    def custom_prepare_fx(self, cfg, example_input=None):
+    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
         example_inputs = (torch.rand(1, 3, 3, 3),)
         self.avgpool = prepare_qat_fx(
             self.avgpool,
...
@@ -171,7 +171,7 @@ class TestLightningTask(unittest.TestCase):
         self.avgpool.preserved_attr = "foo"
         self.avgpool.not_preserved_attr = "bar"

-    def custom_prepare_fx(self, cfg, example_input=None):
+    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
         example_inputs = (torch.rand(1, 3, 3, 3),)
         self.avgpool = prepare_qat_fx(
             self.avgpool,
...
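
Both test hooks gain the same extra positional argument. For reference, a runnable standalone variant of the updated hook (the qconfig mapping and backend are assumptions; the tests' own configs are elided in the hunks):

```python
# Runnable sketch mirroring the updated test signature; this toy body
# always uses prepare_qat_fx, as the tests do, regardless of is_qat.
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

class AvgPoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        example_inputs = (torch.rand(1, 3, 3, 3),)
        self.avgpool = prepare_qat_fx(
            self.avgpool,
            get_default_qat_qconfig_mapping("qnnpack"),  # assumed config
            example_inputs,
        )
        return self

m = AvgPoolModule().train()  # prepare_qat_fx requires train mode
m.custom_prepare_fx(cfg=None, is_qat=True)
```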