Commit 9215e1a8 authored by Yanghan Wang, committed by Facebook GitHub Bot

replace custom_convert_fx API with convert callback

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/379

This diff ~~prototypes~~ implements replacing the `custom_convert_fx` API with a convert callback returned by `custom_prepare_fx`.
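
For illustration, the new contract looks roughly like the sketch below. This is a hedged example, not code from this diff: the toy module, its layers, and the qconfig choice are hypothetical, and it ignores `cfg`/`is_qat` for brevity; only the method name and the `(prepared_model, convert_fx_callback)` return shape follow the API introduced here.

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class ToyMetaArch(torch.nn.Module):
    """Hypothetical module illustrating the new `custom_prepare_fx` contract."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    def forward(self, x):
        return self.backbone(x)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # `cfg` is unused in this toy sketch; a real meta-arch would derive its
        # qconfig from it. QAT would use prepare_qat_fx instead of prepare_fx.
        example = example_input if example_input is not None else torch.randn(1, 3, 32, 32)
        self.backbone = prepare_fx(self.backbone, get_default_qconfig_mapping(), (example,))

        def convert_fx_callback(model):
            # Called later on the prepared model to produce the int8 model.
            model.backbone = convert_fx(model.backbone)
            return model

        # New contract: return the prepared model together with its convert
        # callback, instead of implementing a separate `custom_convert_fx`.
        return self, convert_fx_callback
```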

Reviewed By: LiamZhuuu

Differential Revision: D39859228

fbshipit-source-id: 34719d1758c4afa7e47930c12d3443813d3f4546
parent 3c68dda7
@@ -61,9 +61,6 @@ class GeneralizedRCNN(_GeneralizedRCNN):
def custom_prepare_fx(self, cfg, is_qat, example_input=None):
return default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input)
def custom_convert_fx(self, cfg):
return default_rcnn_custom_convert_fx(self, cfg)
def _cast_model_to_device(self, device):
return _cast_detection_model(self, device)
@@ -309,7 +306,10 @@ def default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input=None):
_fx_quant_prepare(model, cfg, is_qat, example_input)
return model
def convert_fx_callback(model):
return default_rcnn_custom_convert_fx(model, cfg)
return model, convert_fx_callback
def _fx_quant_prepare(self, cfg, is_qat, example_input):
......
@@ -29,6 +29,8 @@ else:
logger = logging.getLogger(__name__)
_CONVERT_FX_CALLBACK_ATTRIBUTE = "_convert_fx_callback"
def _is_observer_key(state_dict_key):
observer_keys = ["activation_post_process", "weight_fake_quant"]
@@ -295,11 +297,7 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
model = prepare_fx(model, qconfig_dict, (example_input,))
logger.info("Setup the model with qconfig:\n{}".format(qconfig))
return model
def default_custom_convert_fx(cfg, model):
return convert_fx(model)
return model, convert_fx
def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
@@ -333,12 +331,28 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
model = fuse_utils.swap_modules(model)
if hasattr(model, "custom_prepare_fx"):
model = model.custom_prepare_fx(cfg, is_qat, example_input)
ret = model.custom_prepare_fx(cfg, is_qat, example_input)
if not (isinstance(ret, tuple) and len(ret) == 2):
raise ValueError(
"`custom_prepare_fx` requires return model and convert_callback"
)
model, convert_fx_callback = ret
else:
logger.info(
"Using default implementation for custom_prepare_fx (FX graph mode)"
)
model = default_custom_prepare_fx(cfg, model, is_qat, example_input)
model, convert_fx_callback = default_custom_prepare_fx(
cfg, model, is_qat, example_input
)
# HACK: store the convert_callback function as model attribute, which can be
# later accessed to convert fake quant model to quantized model. We'll find a
# better place to store this.
if hasattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
raise AttributeError(
f"{_CONVERT_FX_CALLBACK_ATTRIBUTE} is already set in model: {model}"
)
setattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE, convert_fx_callback)
return model
@@ -352,10 +366,15 @@ def convert_to_quantized_model(cfg, fp32_model):
int8_model = convert(fp32_model, inplace=False)
else:
# FX graph mode quantization
if hasattr(fp32_model, "custom_convert_fx"):
int8_model = fp32_model.custom_convert_fx(cfg)
else:
int8_model = convert_fx(fp32_model)
if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
raise AttributeError(
f"Can't find {_CONVERT_FX_CALLBACK_ATTRIBUTE} in model, please check "
f"`prepare_fake_quant_model` has been called: {fp32_model}"
)
convert_fx_callback = getattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE)
int8_model = convert_fx_callback(fp32_model)
return int8_model
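
With this change, the FX convert path no longer probes for `custom_convert_fx`; it requires the callback stashed on the model by `prepare_fake_quant_model`. A hedged sketch of the intended call order, assuming `cfg` and `model` come from the usual d2go runner/meta-arch setup:

```python
from d2go.quantization.modeling import (
    convert_to_quantized_model,
    prepare_fake_quant_model,
)

# `cfg` and `model` are assumed to be built by the usual d2go flow (not shown here).
fake_quant_model = prepare_fake_quant_model(cfg, model, is_qat=False)
# ... run post-training calibration here (or QAT training with is_qat=True) ...
int8_model = convert_to_quantized_model(cfg, fake_quant_model)
```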
@@ -617,3 +636,15 @@ def _reset_qat_data_loader_if_needed(cfg, trainer, build_loader_func):
# This method assumes the data loader can be replaced from trainer
assert trainer.__class__ == SimpleTrainer
trainer.reset_data_loader(lambda: build_loader_func(loader_cfg))
def forward_custom_prepare_fx(root, sub_module_name, orig_ret):
"""Helper function to forward return of `custom_prepare_fx` from sub module"""
new_sub_module, callback = orig_ret
setattr(root, sub_module_name, new_sub_module)
def new_callback(m):
setattr(m, sub_module_name, callback(getattr(m, sub_module_name)))
return m
return root, new_callback
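
`forward_custom_prepare_fx` is the glue for modules that wrap a quantizable submodule: it re-attaches the prepared submodule to the parent and rewraps the callback so it converts the parent. A hedged sketch; the `Wrapper` class is hypothetical, and the helper is assumed to be importable from `d2go.quantization.modeling` alongside the other functions in this hunk.

```python
import torch

from d2go.quantization.modeling import forward_custom_prepare_fx


class Wrapper(torch.nn.Module):
    """Hypothetical wrapper whose quantizable logic lives in `self.backbone`."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        return self.backbone(x)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # The backbone follows the new contract and returns (prepared, callback).
        orig_ret = self.backbone.custom_prepare_fx(cfg, is_qat, example_input)
        # Re-attach the prepared backbone to `self` and return a callback that,
        # when invoked on the wrapper, converts the backbone in place.
        return forward_custom_prepare_fx(self, "backbone", orig_ret)
```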
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from d2go.config import CfgNode
from d2go.quantization.modeling import prepare_fake_quant_model
from d2go.utils.misc import mode
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from pytorch_lightning import LightningModule, Trainer
@@ -99,7 +100,10 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint
setattr(model, PREPARED, _deepcopy(model).custom_prepare_fx(is_qat=True))
copied = _deepcopy(model)
prepared = prepare_fake_quant_model(copied.cfg, copied.model, is_qat=True)
copied.model = prepared
setattr(model, PREPARED, copied)
class QuantizationMixin(ABC):
@@ -162,8 +166,14 @@ class QuantizationMixin(ABC):
The prepared Module to be used for quantized aware training.
"""
is_qat = isinstance(self, QuantizationAwareTraining)
if hasattr(root, "custom_prepare_fx"):
return root.custom_prepare_fx(is_qat)
self._convert_fx_callback = None
if hasattr(root.model, "custom_prepare_fx"):
prepared, convert_fx_callback = root.model.custom_prepare_fx(
root.cfg, is_qat
)
self._convert_fx_callback = convert_fx_callback
root.model = prepared
return root
prep_fn = prepare_qat_fx if is_qat else prepare_fx
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
@@ -203,8 +213,8 @@ class QuantizationMixin(ABC):
Returns:
The quantized model.
"""
if hasattr(root, "custom_convert_fx"):
return root.custom_convert_fx()
if self._convert_fx_callback is not None:
return self._convert_fx_callback(root)
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
......
@@ -17,10 +17,6 @@ from d2go.data.utils import update_cfg_if_using_adhoc_dataset
from d2go.modeling.api import build_meta_arch
from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
default_custom_convert_fx,
default_custom_prepare_fx,
)
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import (
_get_tbx_writer,
@@ -505,26 +501,6 @@ class DefaultTask(pl.LightningModule):
self.ema_state.load_state_dict(checkpointed_state["model_ema"])
rank_zero_info("Loaded EMA state from checkpoint.")
# TODO: remove custom_prepare_fx/custom_convert_fx from LightningModule
def custom_prepare_fx(self, is_qat) -> pl.LightningModule:
if hasattr(self.model, "custom_prepare_fx"):
self.model = self.model.custom_prepare_fx(
self.cfg, is_qat, example_input=None
)
else:
self.model = default_custom_prepare_fx(
self.cfg, self.model, is_qat, example_input=None
)
return self
def custom_convert_fx(self) -> pl.LightningModule:
if hasattr(self.model, "custom_convert_fx"):
self.model = self.model.custom_convert_fx(self.cfg)
else:
self.model = default_custom_convert_fx(self.cfg, self.model)
return self
# TODO(T123654122): subclass of DefaultTask will be refactored
class GeneralizedRCNNTask(DefaultTask):
......
@@ -6,7 +6,7 @@ import logging
import os
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterator
from typing import Any, Callable, Dict, Iterator, Optional
# @manual=//vision/fair/detectron2/detectron2:detectron2
import detectron2.utils.comm as comm
@@ -129,16 +129,27 @@ def _log_api_usage(identifier: str):
torch._C._log_api_usage_once("d2go." + identifier)
def inplace_delegate(self, api_name, sub_module_name, *args, **kwargs):
def inplace_delegate(
self,
api_name: str,
sub_module_name: str,
setter_fn: Optional[Callable],
*args,
**kwargs,
) -> Any:
"""Helper function to delegate API calls to its submodule"""
sub_module = getattr(self, sub_module_name)
api_name = f"delegate_{api_name}"
if hasattr(sub_module, api_name):
func = getattr(sub_module, api_name)
# Assume the return of `func` will replace the submodule
setattr(self, sub_module_name, func(*args, **kwargs))
return self
orig_ret = func(*args, **kwargs)
if setter_fn is None:
# Assume the return of `func` will replace the submodule
setattr(self, sub_module_name, orig_ret)
return self
else:
return setter_fn(self, sub_module_name, orig_ret)
else:
raise RuntimeError(
f"It seems the {sub_module_name} doesn't implement {api_name},"
......
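
The new `setter_fn` argument exists for delegated APIs whose return value is not a drop-in submodule replacement; for `custom_prepare_fx`, `forward_custom_prepare_fx` can serve as the setter. A hedged sketch of the pairing: the `Parent` class and the import paths are assumptions, and the submodule's `delegate_custom_prepare_fx` method follows the `delegate_` prefix convention visible in the code above.

```python
import torch

from d2go.quantization.modeling import forward_custom_prepare_fx
from d2go.utils.misc import inplace_delegate  # assumed import path


class Parent(torch.nn.Module):
    """Hypothetical parent that owns a quantizable submodule at `self.model`."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # expected to implement `delegate_custom_prepare_fx`

    def forward(self, x):
        return self.model(x)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # `inplace_delegate` calls `self.model.delegate_custom_prepare_fx(...)`;
        # passing `forward_custom_prepare_fx` as setter_fn re-attaches the prepared
        # submodule and re-targets the returned convert callback at this parent.
        return inplace_delegate(
            self,
            "custom_prepare_fx",
            "model",
            forward_custom_prepare_fx,
            cfg,
            is_qat,
            example_input,
        )
```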
@@ -59,11 +59,12 @@ class DetMetaArchForTest(torch.nn.Module):
{"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
example_inputs,
)
return self
def custom_convert_fx(self, cfg):
self.avgpool = convert_fx(self.avgpool)
return self
def convert_fx_callback(model):
model.avgpool = convert_fx(model.avgpool)
return model
return self, convert_fx_callback
def get_det_meta_arch_cfg(cfg, dataset_name, output_dir):
......
@@ -214,17 +214,19 @@ class TestLightningTask(unittest.TestCase):
example_inputs,
self.custom_config_dict,
)
return self
def custom_convert_fx(self, cfg):
self.avgpool = convert_fx(
self.avgpool, convert_custom_config=self.custom_config_dict
)
return self
def convert_fx_callback(model):
model.avgpool = convert_fx(
model.avgpool, convert_custom_config=model.custom_config_dict
)
return model
return self, convert_fx_callback
cfg = self._get_cfg(tmp_dir)
cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
cfg.QUANTIZATION.QAT.ENABLED = True
cfg.QUANTIZATION.EAGER_MODE = False
task = GeneralizedRCNNTask(cfg)
callbacks = [
......