Commit 75176365 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

re-enable lightning task's test_qat

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/330

- re-enable the `test_qat`
- remove `example_input` from `DetMetaArchForTest`, since its `custom_prepare_fx` just create a tensor for avgpool.

Reviewed By: jerryzh168

Differential Revision: D37793260

fbshipit-source-id: ec7a825c61292d9c6d792f910a957c1c27832336
parent 74bc35ea
...@@ -99,7 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool: ...@@ -99,7 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]): def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED): if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint # model has been prepared for QAT before saving into checkpoint
setattr(model, PREPARED, _deepcopy(model).prepare_for_quant()) setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx())
class QuantizationMixin(ABC): class QuantizationMixin(ABC):
...@@ -161,8 +161,8 @@ class QuantizationMixin(ABC): ...@@ -161,8 +161,8 @@ class QuantizationMixin(ABC):
Returns: Returns:
The prepared Module to be used for quantized aware training. The prepared Module to be used for quantized aware training.
""" """
if hasattr(root, "prepare_for_quant"): if hasattr(root, "custom_prepare_fx"):
return root.prepare_for_quant() return root.custom_prepare_fx()
prep_fn = ( prep_fn = (
prepare_qat_fx prepare_qat_fx
if isinstance(self, QuantizationAwareTraining) if isinstance(self, QuantizationAwareTraining)
......
...@@ -19,7 +19,7 @@ from d2go.modeling.model_freezing_utils import set_requires_grad ...@@ -19,7 +19,7 @@ from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import ( from d2go.quantization.modeling import (
default_custom_convert_fx, default_custom_convert_fx,
default_prepare_for_quant, default_custom_prepare_fx,
) )
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import ( from d2go.runner.default_runner import (
...@@ -475,12 +475,13 @@ class DefaultTask(pl.LightningModule): ...@@ -475,12 +475,13 @@ class DefaultTask(pl.LightningModule):
self.ema_state.load_state_dict(checkpointed_state["model_ema"]) self.ema_state.load_state_dict(checkpointed_state["model_ema"])
rank_zero_info("Loaded EMA state from checkpoint.") rank_zero_info("Loaded EMA state from checkpoint.")
def prepare_for_quant(self) -> pl.LightningModule: def custom_prepare_fx(self) -> pl.LightningModule:
example_input = self.model.example_input if hasattr(self.model, "custom_prepare_fx"):
if hasattr(self.model, "prepare_for_quant"): self.model = self.model.custom_prepare_fx(self.cfg, example_input=None)
self.model = self.model.prepare_for_quant(self.cfg, example_input)
else: else:
self.model = default_prepare_for_quant(self.cfg, self.model, example_input) self.model = default_custom_prepare_fx(
self.cfg, self.model, example_input=None
)
return self return self
def custom_convert_fx(self) -> pl.LightningModule: def custom_convert_fx(self) -> pl.LightningModule:
......
...@@ -26,11 +26,6 @@ class DetMetaArchForTest(torch.nn.Module): ...@@ -26,11 +26,6 @@ class DetMetaArchForTest(torch.nn.Module):
def device(self): def device(self):
return self.conv.weight.device return self.conv.weight.device
@property
def example_input(self):
# TODO[quant-example-inputs]: set example_input properly
return torch.randn(1, 3, 224, 224)
def forward(self, inputs): def forward(self, inputs):
if not self.training: if not self.training:
return self.inference(inputs) return self.inference(inputs)
...@@ -57,8 +52,7 @@ class DetMetaArchForTest(torch.nn.Module): ...@@ -57,8 +52,7 @@ class DetMetaArchForTest(torch.nn.Module):
ret = [{"instances": instance}] ret = [{"instances": instance}]
return ret return ret
def prepare_for_quant(self, cfg, example_input=None): def custom_prepare_fx(self, cfg, example_input=None):
# TODO[quant-example-inputs]: use example_input
example_inputs = (torch.rand(1, 3, 3, 3),) example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx( self.avgpool = prepare_qat_fx(
self.avgpool, self.avgpool,
......
...@@ -161,9 +161,6 @@ class TestLightningTask(unittest.TestCase): ...@@ -161,9 +161,6 @@ class TestLightningTask(unittest.TestCase):
) )
@tempdir @tempdir
@unittest.skip(
"FX Graph Mode Quantization API has been updated, re-enable the test after PyTorch 1.13 stable release"
)
def test_qat(self, tmp_dir): def test_qat(self, tmp_dir):
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest): class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
...@@ -174,7 +171,7 @@ class TestLightningTask(unittest.TestCase): ...@@ -174,7 +171,7 @@ class TestLightningTask(unittest.TestCase):
self.avgpool.preserved_attr = "foo" self.avgpool.preserved_attr = "foo"
self.avgpool.not_preserved_attr = "bar" self.avgpool.not_preserved_attr = "bar"
def prepare_for_quant(self, cfg): def custom_prepare_fx(self, cfg, example_input=None):
example_inputs = (torch.rand(1, 3, 3, 3),) example_inputs = (torch.rand(1, 3, 3, 3),)
self.avgpool = prepare_qat_fx( self.avgpool = prepare_qat_fx(
self.avgpool, self.avgpool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment