Commit 9215e1a8 authored by Yanghan Wang, committed by Facebook GitHub Bot

replace custom_convert_fx API with convert callback

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/379

This diff ~~prototypes~~ implements replacing the `custom_convert_fx` API with a convert callback returned by `custom_prepare_fx`.
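
To illustrate the new contract, here is a minimal sketch (the `QuantizableArch` module, its `avgpool` submodule, and the default fbgemm qconfig are illustrative assumptions, not code from this diff): `custom_prepare_fx` now returns the prepared model together with a `convert_fx_callback`, which takes over the role of the removed `custom_convert_fx(cfg)` method.

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class QuantizableArch(torch.nn.Module):
    """Illustrative meta-arch sketch; not part of this diff."""

    def __init__(self):
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return self.avgpool(x)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # Prepare only the quantizable submodule; a real model would derive its
        # qconfig from `cfg` rather than the default fbgemm qconfig used here,
        # and `example_input` is assumed to be a sample tensor.
        self.avgpool = prepare_fx(
            self.avgpool, {"": get_default_qconfig("fbgemm")}, (example_input,)
        )

        # The callback replaces the old `custom_convert_fx` method: it converts
        # the prepared submodule when later invoked on the fake-quant model.
        def convert_fx_callback(model):
            model.avgpool = convert_fx(model.avgpool)
            return model

        return self, convert_fx_callback
```

`prepare_fake_quant_model` stashes the returned callback on the model under `_convert_fx_callback`, and `convert_to_quantized_model` later looks it up and calls it, so models no longer need to expose a separate `custom_convert_fx` API.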

Reviewed By: LiamZhuuu

Differential Revision: D39859228

fbshipit-source-id: 34719d1758c4afa7e47930c12d3443813d3f4546
parent 3c68dda7
@@ -61,9 +61,6 @@ class GeneralizedRCNN(_GeneralizedRCNN):
     def custom_prepare_fx(self, cfg, is_qat, example_input=None):
         return default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input)
 
-    def custom_convert_fx(self, cfg):
-        return default_rcnn_custom_convert_fx(self, cfg)
-
     def _cast_model_to_device(self, device):
         return _cast_detection_model(self, device)
@@ -309,7 +306,10 @@ def default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input=None):
     _fx_quant_prepare(model, cfg, is_qat, example_input)
-    return model
+
+    def convert_fx_callback(model):
+        return default_rcnn_custom_convert_fx(model, cfg)
+
+    return model, convert_fx_callback
 
 
 def _fx_quant_prepare(self, cfg, is_qat, example_input):
@@ -29,6 +29,8 @@ else:
 logger = logging.getLogger(__name__)
 
+_CONVERT_FX_CALLBACK_ATTRIBUTE = "_convert_fx_callback"
+
 
 def _is_observer_key(state_dict_key):
     observer_keys = ["activation_post_process", "weight_fake_quant"]
@@ -295,11 +297,7 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
         model = prepare_fx(model, qconfig_dict, (example_input,))
 
     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
-    return model
-
-
-def default_custom_convert_fx(cfg, model):
-    return convert_fx(model)
+    return model, convert_fx
 
 
 def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
@@ -333,12 +331,28 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
         model = fuse_utils.swap_modules(model)
 
         if hasattr(model, "custom_prepare_fx"):
-            model = model.custom_prepare_fx(cfg, is_qat, example_input)
+            ret = model.custom_prepare_fx(cfg, is_qat, example_input)
+            if not (isinstance(ret, tuple) and len(ret) == 2):
+                raise ValueError(
+                    "`custom_prepare_fx` requires return model and convert_callback"
+                )
+            model, convert_fx_callback = ret
         else:
             logger.info(
                 "Using default implementation for custom_prepare_fx (FX graph mode)"
             )
-            model = default_custom_prepare_fx(cfg, model, is_qat, example_input)
+            model, convert_fx_callback = default_custom_prepare_fx(
+                cfg, model, is_qat, example_input
+            )
+
+        # HACK: store the convert_callback function as model attribute, which can be
+        # later accessed to convert fake quant model to quantized model. We'll find a
+        # better place to store this.
+        if hasattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
+            raise AttributeError(
+                f"{_CONVERT_FX_CALLBACK_ATTRIBUTE} is already set in model: {model}"
+            )
+        setattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE, convert_fx_callback)
 
     return model
@@ -352,10 +366,15 @@ def convert_to_quantized_model(cfg, fp32_model):
         int8_model = convert(fp32_model, inplace=False)
     else:
         # FX graph mode quantization
-        if hasattr(fp32_model, "custom_convert_fx"):
-            int8_model = fp32_model.custom_convert_fx(cfg)
-        else:
-            int8_model = convert_fx(fp32_model)
+        if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
+            raise AttributeError(
+                f"Can't find {_CONVERT_FX_CALLBACK_ATTRIBUTE} in model, please check "
+                f"`prepare_fake_quant_model` has been called: {fp32_model}"
+            )
+        convert_fx_callback = getattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE)
+        int8_model = convert_fx_callback(fp32_model)
 
     return int8_model
@@ -617,3 +636,15 @@ def _reset_qat_data_loader_if_needed(cfg, trainer, build_loader_func):
     # This method assumes the data loader can be replaced from trainer
     assert trainer.__class__ == SimpleTrainer
     trainer.reset_data_loader(lambda: build_loader_func(loader_cfg))
+
+
+def forward_custom_prepare_fx(root, sub_module_name, orig_ret):
+    """Helper function to forward return of `custom_prepare_fx` from sub module"""
+    new_sub_module, callback = orig_ret
+    setattr(root, sub_module_name, new_sub_module)
+
+    def new_callback(m):
+        setattr(m, sub_module_name, callback(getattr(m, sub_module_name)))
+        return m
+
+    return root, new_callback
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from d2go.config import CfgNode
+from d2go.quantization.modeling import prepare_fake_quant_model
 from d2go.utils.misc import mode
 from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from pytorch_lightning import LightningModule, Trainer
@@ -99,7 +100,10 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
 def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
     if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
         # model has been prepared for QAT before saving into checkpoint
-        setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx(is_qat=True))
+        copied = _deepcopy(model)
+        prepared = prepare_fake_quant_model(copied.cfg, copied.model, is_qat=True)
+        copied.model = prepared
+        setattr(model, PREPARED, copied)
 
 
 class QuantizationMixin(ABC):
@@ -162,8 +166,14 @@ class QuantizationMixin(ABC):
             The prepared Module to be used for quantized aware training.
         """
         is_qat = isinstance(self, QuantizationAwareTraining)
-        if hasattr(root, "custom_prepare_fx"):
-            return root.custom_prepare_fx(is_qat)
+        self._convert_fx_callback = None
+        if hasattr(root.model, "custom_prepare_fx"):
+            prepared, convert_fx_callback = root.model.custom_prepare_fx(
+                root.cfg, is_qat
+            )
+            self._convert_fx_callback = convert_fx_callback
+            root.model = prepared
+            return root
         prep_fn = prepare_qat_fx if is_qat else prepare_fx
         old_attrs = {
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
@@ -203,8 +213,8 @@ class QuantizationMixin(ABC):
         Returns:
             The quantized model.
         """
-        if hasattr(root, "custom_convert_fx"):
-            return root.custom_convert_fx()
+        if self._convert_fx_callback is not None:
+            return self._convert_fx_callback(root)
         old_attrs = {
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
         }
@@ -17,10 +17,6 @@ from d2go.data.utils import update_cfg_if_using_adhoc_dataset
 from d2go.modeling.api import build_meta_arch
 from d2go.modeling.model_freezing_utils import set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
-from d2go.quantization.modeling import (
-    default_custom_convert_fx,
-    default_custom_prepare_fx,
-)
 from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from d2go.runner.default_runner import (
     _get_tbx_writer,
@@ -505,26 +501,6 @@ class DefaultTask(pl.LightningModule):
             self.ema_state.load_state_dict(checkpointed_state["model_ema"])
             rank_zero_info("Loaded EMA state from checkpoint.")
 
-    # TODO: remove custom_prepare_fx/custom_convert_fx from LightningModule
-    def custom_prepare_fx(self, is_qat) -> pl.LightningModule:
-        if hasattr(self.model, "custom_prepare_fx"):
-            self.model = self.model.custom_prepare_fx(
-                self.cfg, is_qat, example_input=None
-            )
-        else:
-            self.model = default_custom_prepare_fx(
-                self.cfg, self.model, is_qat, example_input=None
-            )
-        return self
-
-    def custom_convert_fx(self) -> pl.LightningModule:
-        if hasattr(self.model, "custom_convert_fx"):
-            self.model = self.model.custom_convert_fx(self.cfg)
-        else:
-            self.model = default_custom_convert_fx(self.cfg, self.model)
-        return self
-
 
 # TODO(T123654122): subclass of DefaultTask will be refactored
 class GeneralizedRCNNTask(DefaultTask):
@@ -6,7 +6,7 @@ import logging
 import os
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator
+from typing import Any, Callable, Dict, Iterator, Optional
 
 # @manual=//vision/fair/detectron2/detectron2:detectron2
 import detectron2.utils.comm as comm
@@ -129,16 +129,27 @@ def _log_api_usage(identifier: str):
     torch._C._log_api_usage_once("d2go." + identifier)
 
 
-def inplace_delegate(self, api_name, sub_module_name, *args, **kwargs):
+def inplace_delegate(
+    self,
+    api_name: str,
+    sub_module_name: str,
+    setter_fn: Optional[Callable],
+    *args,
+    **kwargs,
+) -> Any:
     """Helper function to delegate API calls to its submodule"""
     sub_module = getattr(self, sub_module_name)
     api_name = f"delegate_{api_name}"
     if hasattr(sub_module, api_name):
         func = getattr(sub_module, api_name)
-        # Assume the return of `func` will replace the submodule
-        setattr(self, sub_module_name, func(*args, **kwargs))
-        return self
+        orig_ret = func(*args, **kwargs)
+        if setter_fn is None:
+            # Assume the return of `func` will replace the submodule
+            setattr(self, sub_module_name, orig_ret)
+            return self
+        else:
+            return setter_fn(self, sub_module_name, orig_ret)
     else:
         raise RuntimeError(
             f"It seems the {sub_module_name} doesn't implement {api_name},"
@@ -59,11 +59,12 @@ class DetMetaArchForTest(torch.nn.Module):
             {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
            example_inputs,
        )
-        return self
 
-    def custom_convert_fx(self, cfg):
-        self.avgpool = convert_fx(self.avgpool)
-        return self
+        def convert_fx_callback(model):
+            model.avgpool = convert_fx(model.avgpool)
+            return model
+
+        return self, convert_fx_callback
 
 
 def get_det_meta_arch_cfg(cfg, dataset_name, output_dir):
@@ -214,17 +214,19 @@ class TestLightningTask(unittest.TestCase):
                     example_inputs,
                     self.custom_config_dict,
                 )
-                return self
 
-            def custom_convert_fx(self, cfg):
-                self.avgpool = convert_fx(
-                    self.avgpool, convert_custom_config=self.custom_config_dict
-                )
-                return self
+                def convert_fx_callback(model):
+                    model.avgpool = convert_fx(
+                        model.avgpool, convert_custom_config=model.custom_config_dict
+                    )
+                    return model
+
+                return self, convert_fx_callback
 
         cfg = self._get_cfg(tmp_dir)
         cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
         cfg.QUANTIZATION.QAT.ENABLED = True
+        cfg.QUANTIZATION.EAGER_MODE = False
         task = GeneralizedRCNNTask(cfg)
 
         callbacks = [