Commit 75176365 authored by Yanghan Wang, committed by Facebook GitHub Bot

re-enable lightning task's test_qat

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/330

- re-enable the `test_qat`
- remove `example_input` from `DetMetaArchForTest`, since its `custom_prepare_fx` just creates a tensor for the avgpool.
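
As a rough illustration (not part of this diff), here is a minimal sketch of what an opted-in meta-arch's `custom_prepare_fx` hook looks like after this change. `MyMetaArch`, the avgpool submodule, and the qconfig mapping are placeholder assumptions; only the hook signature and the locally built `example_inputs` mirror the hunks below.

```python
# Minimal sketch; MyMetaArch, the avgpool submodule and the qconfig mapping are
# placeholders -- only the hook shape mirrors this diff.
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx


class MyMetaArch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)

    def custom_prepare_fx(self, cfg, example_input=None):
        # The hook builds its own example input for the submodule it prepares,
        # which is why the model-level `example_input` property can be dropped.
        example_inputs = (torch.rand(1, 3, 3, 3),)
        self.avgpool = prepare_qat_fx(
            self.avgpool,
            get_default_qat_qconfig_mapping("fbgemm"),
            example_inputs,
        )
        return self
```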

Reviewed By: jerryzh168

Differential Revision: D37793260

fbshipit-source-id: ec7a825c61292d9c6d792f910a957c1c27832336
parent 74bc35ea
@@ -99,7 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
 def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
     if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
         # model has been prepared for QAT before saving into checkpoint
-        setattr(model, PREPARED, _deepcopy(model).prepare_for_quant())
+        setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx())


 class QuantizationMixin(ABC):
@@ -161,8 +161,8 @@ class QuantizationMixin(ABC):
         Returns:
             The prepared Module to be used for quantized aware training.
         """
-        if hasattr(root, "prepare_for_quant"):
-            return root.prepare_for_quant()
+        if hasattr(root, "custom_prepare_fx"):
+            return root.custom_prepare_fx()
         prep_fn = (
             prepare_qat_fx
             if isinstance(self, QuantizationAwareTraining)
......
@@ -19,7 +19,7 @@ from d2go.modeling.model_freezing_utils import set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
 from d2go.quantization.modeling import (
     default_custom_convert_fx,
-    default_prepare_for_quant,
+    default_custom_prepare_fx,
 )
 from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from d2go.runner.default_runner import (
@@ -475,12 +475,13 @@ class DefaultTask(pl.LightningModule):
                 self.ema_state.load_state_dict(checkpointed_state["model_ema"])
                 rank_zero_info("Loaded EMA state from checkpoint.")

-    def prepare_for_quant(self) -> pl.LightningModule:
-        example_input = self.model.example_input
-        if hasattr(self.model, "prepare_for_quant"):
-            self.model = self.model.prepare_for_quant(self.cfg, example_input)
+    def custom_prepare_fx(self) -> pl.LightningModule:
+        if hasattr(self.model, "custom_prepare_fx"):
+            self.model = self.model.custom_prepare_fx(self.cfg, example_input=None)
         else:
-            self.model = default_prepare_for_quant(self.cfg, self.model, example_input)
+            self.model = default_custom_prepare_fx(
+                self.cfg, self.model, example_input=None
+            )
         return self

     def custom_convert_fx(self) -> pl.LightningModule:
......
@@ -26,11 +26,6 @@ class DetMetaArchForTest(torch.nn.Module):
     def device(self):
         return self.conv.weight.device

-    @property
-    def example_input(self):
-        # TODO[quant-example-inputs]: set example_input properly
-        return torch.randn(1, 3, 224, 224)
-
     def forward(self, inputs):
         if not self.training:
             return self.inference(inputs)
@@ -57,8 +52,7 @@ class DetMetaArchForTest(torch.nn.Module):
         ret = [{"instances": instance}]
         return ret

-    def prepare_for_quant(self, cfg, example_input=None):
-        # TODO[quant-example-inputs]: use example_input
+    def custom_prepare_fx(self, cfg, example_input=None):
         example_inputs = (torch.rand(1, 3, 3, 3),)
         self.avgpool = prepare_qat_fx(
             self.avgpool,
......
@@ -161,9 +161,6 @@ class TestLightningTask(unittest.TestCase):
     )
     @tempdir
-    @unittest.skip(
-        "FX Graph Mode Quantization API has been updated, re-enable the test after PyTorch 1.13 stable release"
-    )
     def test_qat(self, tmp_dir):
         @META_ARCH_REGISTRY.register()
         class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
@@ -174,7 +171,7 @@ class TestLightningTask(unittest.TestCase):
                 self.avgpool.preserved_attr = "foo"
                 self.avgpool.not_preserved_attr = "bar"

-            def prepare_for_quant(self, cfg):
+            def custom_prepare_fx(self, cfg, example_input=None):
                 example_inputs = (torch.rand(1, 3, 3, 3),)
                 self.avgpool = prepare_qat_fx(
                     self.avgpool,
......