Commit ef9c20cc authored by Peizhao Zhang, committed by Facebook GitHub Bot

Refactored QAT-related code.

Summary:
Refactored QAT-related code.
* Moved the `_prepare_model_for_qat`-related code into its own function.
* Moved the `_setup_non_qat_to_qat_state_dict_map`-related code into its own function.
* Moved the QATHook-related code to the quantization file and implemented it as a class (see the usage sketch below).
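
For reference, a minimal sketch of how the refactored pieces fit together after this change; `runner` and `trainer` stand in for the usual d2go runner and detectron2 trainer and are not part of this diff:

```python
from d2go.modeling.quantization import QATHook, setup_qat_model

# Prepare an fp32 model for QAT; this replaces the old inline logic in setup_qat_model.
model = setup_qat_model(cfg, model_fp32, enable_fake_quant=False, enable_observer=False)

# QAT phase changes (fake quant, observers, BN freezing) are now driven by a hook
# instead of a closure built in Detectron2GoRunner._create_qat_hook.
if cfg.QUANTIZATION.QAT.ENABLED:
    trainer.register_hooks([QATHook(cfg, runner.build_detection_train_loader)])
```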

Differential Revision: D31370819

fbshipit-source-id: 836550b2c8d68cd93a84d5877ad9cef6f0f0eb39
parent bfc08c53
@@ -2,14 +2,18 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import inspect
import logging
import math
from typing import Tuple
import detectron2.utils.comm as comm
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import HookBase
from detectron2.engine import SimpleTrainer
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate
@@ -185,7 +189,7 @@ def mock_quantization_type(quant_func):
@functools.wraps(quant_func)
def wrapper(cfg, model, *args, **kwargs):
if type(d2l.Linear) == torch.nn.Linear:
if d2l.Linear == torch.nn.Linear:
# we do not need the mock when the type is as expected; consider
# removing the related code
logger.warning(
@@ -317,15 +321,7 @@ def post_training_quantize(cfg, model, data_loader):
return model
@mock_quantization_type
def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
if hasattr(model, "_non_qat_to_qat_state_dict_map"):
raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
device = model.device
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
original_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
def _prepare_model_for_qat(cfg, model):
if cfg.QUANTIZATION.EAGER_MODE:
if hasattr(model, "prepare_for_quant"):
model = model.prepare_for_quant(cfg)
@@ -341,6 +337,27 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
else:
logger.info("Using default implementation for prepare_for_quant")
model = default_prepare_for_quant(cfg, model)
return model
@mock_quantization_type
def setup_qat_model(
cfg,
model_fp32,
enable_fake_quant: bool = False,
enable_observer: bool = False,
):
if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"):
raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
device = model_fp32.device
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
# preparing for QAT may modify the fp32 model in place, so capture its state dict first
model_fp32_state_dict = model_fp32.state_dict()
# prepare model for qat
model = _prepare_model_for_qat(cfg, model_fp32)
# Move newly added observers to the original device
model.to(device)
@@ -352,21 +369,33 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
logger.info("Disabling observer ...")
model.apply(torch.ao.quantization.disable_observer)
# qat state dict mapper
model = _setup_non_qat_to_qat_state_dict_map(
model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
)
return model
def _setup_non_qat_to_qat_state_dict_map(
model_fp32_state_dict, model_qat, is_eager_mode
):
original_state_dict_shapes = {k: v.shape for k, v in model_fp32_state_dict.items()}
# fuse_model and prepare_qat may change the state_dict keys of the model; keep a map
# from the original keys to the QAT keys so weights can be loaded from a non-QAT model.
new_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
new_state_dict_shapes = {k: v.shape for k, v in model_qat.state_dict().items()}
new_state_dict_non_observer_keys = [
k for k in new_state_dict_shapes if not _is_observer_key(k)
]
assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
if cfg.QUANTIZATION.EAGER_MODE:
if is_eager_mode:
for n_k, o_k in zip(
new_state_dict_non_observer_keys, original_state_dict_shapes
):
assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
# map each original (non-QAT) key to the corresponding QAT key
model._non_qat_to_qat_state_dict_map = dict(
model_qat._non_qat_to_qat_state_dict_map = dict(
zip(original_state_dict_shapes, new_state_dict_non_observer_keys)
)
else:
@@ -386,13 +415,109 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
# - bn
return old_bn_key.replace(".bn.", ".conv.bn.")
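# e.g. a hypothetical key "stage1.block1.bn.weight" maps to "stage1.block1.conv.bn.weight"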
model._non_qat_to_qat_state_dict_map = {}
model_qat._non_qat_to_qat_state_dict_map = {}
for key in original_state_dict_shapes.keys():
if key in new_state_dict_non_observer_keys:
model._non_qat_to_qat_state_dict_map[key] = key
model_qat._non_qat_to_qat_state_dict_map[key] = key
else:
maybe_new_bn_key = get_new_bn_key(key)
if maybe_new_bn_key in new_state_dict_non_observer_keys:
model._non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
model_qat._non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
return model_qat
class QATHook(HookBase):
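"""
Drives QAT phase changes during training: enables fake quant at START_ITER,
enables/disables observers around [ENABLE_OBSERVER_ITER, DISABLE_OBSERVER_ITER),
freezes BN stats from FREEZE_BN_ITER on, and optionally updates observer
statistics periodically (all thresholds come from cfg.QUANTIZATION.QAT).
"""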
def __init__(self, cfg, build_data_loader_func=None):
self.cfg = cfg
self.build_data_loader_func = build_data_loader_func
self._applied = {
"enable_fake_quant": False,
"enable_observer": False,
"disable_observer": False,
"freeze_bn_stats": False,
}
assert (
cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
<= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
), "Can't diable observer before enabling it"
def before_step(self):
cur_iter = self.trainer.iter
model = self.trainer.model
cfg = self.cfg
if (
not self._applied["enable_fake_quant"]
and cur_iter >= cfg.QUANTIZATION.QAT.START_ITER
):
logger.info(
"[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
)
model.apply(torch.ao.quantization.enable_fake_quant)
self._applied["enable_fake_quant"] = True
_reset_qat_data_loader_if_needed(
self.cfg, self.trainer, self.build_data_loader_func
)
if (
not self._applied["enable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
and cur_iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
):
logger.info("[QAT] enable observer, iter = {}".format(cur_iter))
model.apply(torch.ao.quantization.enable_observer)
self._applied["enable_observer"] = True
if (
not self._applied["disable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
):
logger.info(
"[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter)
)
model.apply(torch.ao.quantization.disable_observer)
self._applied["disable_observer"] = True
if (
not self._applied["freeze_bn_stats"]
and cur_iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
):
logger.info(
"[QAT] freezing BN for subseq iters, iter = {}".format(cur_iter)
)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
self._applied["freeze_bn_stats"] = True
if (
self._applied["enable_fake_quant"]
and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
and cur_iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD == 0
):
logger.info(f"[QAT] updating observers, iter = {cur_iter}")
model.apply(observer_update_stat)
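# Example (hypothetical config values): with START_ITER=1000, ENABLE_OBSERVER_ITER=1000,
# DISABLE_OBSERVER_ITER=5000 and FREEZE_BN_ITER=6000, fake quant turns on at iter 1000,
# observers run for iters [1000, 5000) and are disabled from iter 5000, and BN stats
# freeze from iter 6000 onwards.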
def _reset_qat_data_loader_if_needed(cfg, trainer, build_loader_func):
if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
loader_cfg = cfg.clone()
loader_cfg.defrost()
num_gpus = comm.get_world_size()
old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
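# e.g. (hypothetical values) old_bs = 8 with BATCH_SIZE_FACTOR = 1.5 gives new_bs = ceil(12.0) = 12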
loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
loader_cfg.freeze()
logger.info(
"[QAT] Rebuild data loader with batch size per GPU: {} -> {}".format(
old_bs, new_bs
)
)
# This method assumes the trainer's data loader can be replaced
assert trainer.__class__ == SimpleTrainer
del trainer._data_loader_iter
del trainer.data_loader
data_loader = build_loader_func(loader_cfg)
trainer.data_loader = data_loader
trainer._data_loader_iter = iter(data_loader)
@@ -3,7 +3,6 @@
import logging
import math
import os
from collections import OrderedDict
from functools import lru_cache
@@ -35,6 +34,7 @@ from d2go.modeling.model_freezing_utils import (
from d2go.modeling.quantization import (
QATCheckpointer,
setup_qat_model,
QATHook,
)
from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import add_flop_printing_hook
@@ -65,7 +65,6 @@ from detectron2.solver import (
)
from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from detectron2.utils.registry import Registry
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from mobile_cv.predictor.api import PredictorWrapper
@@ -593,98 +592,14 @@ class Detectron2GoRunner(BaseRunner):
"""
return None
def _create_qat_hook(self, cfg):
def _create_qat_hook(self, cfg) -> Optional[QATHook]:
"""
Create a hook to start QAT (during training) and/or change the phase of QAT.
"""
applied = {
"enable_fake_quant": False,
"enable_observer": False,
"disable_observer": False,
"freeze_bn_stats": False,
}
assert (
cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
<= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
), "Can't diable observer before enabling it"
def qat_before_step_callback(trainer):
if (
not applied["enable_fake_quant"]
and trainer.iter >= cfg.QUANTIZATION.QAT.START_ITER
):
logger.info(
"[QAT] enable fake quant to start QAT, iter = {}".format(
trainer.iter
)
)
trainer.model.apply(torch.ao.quantization.enable_fake_quant)
applied["enable_fake_quant"] = True
if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
loader_cfg = cfg.clone()
loader_cfg.defrost()
num_gpus = comm.get_world_size()
old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
loader_cfg.freeze()
logger.info(
"[QAT] Rebuild data loader with batch size per GPU: {} -> {}".format(
old_bs, new_bs
)
)
# This method assumes the trainer's data loader can be replaced
assert trainer.__class__ == SimpleTrainer
del trainer._data_loader_iter
del trainer.data_loader
data_loader = self.build_detection_train_loader(loader_cfg)
trainer.data_loader = data_loader
trainer._data_loader_iter = iter(data_loader)
if (
not applied["enable_observer"]
and trainer.iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
):
logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
trainer.model.apply(torch.ao.quantization.enable_observer)
applied["enable_observer"] = True
if (
not applied["disable_observer"]
and trainer.iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
):
logger.info(
"[QAT] disabling observer for sub seq iters, iter = {}".format(
trainer.iter
)
)
trainer.model.apply(torch.ao.quantization.disable_observer)
applied["disable_observer"] = True
if (
not applied["freeze_bn_stats"]
and trainer.iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
):
logger.info(
"[QAT] freezing BN for subseq iters, iter = {}".format(trainer.iter)
)
trainer.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
applied["freeze_bn_stats"] = True
if (
applied["enable_fake_quant"]
and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
and trainer.iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD
== 0
):
logger.info(f"[QAT] updating observers, iter = {trainer.iter}")
trainer.model.apply(observer_update_stat)
if not cfg.QUANTIZATION.QAT.ENABLED:
return None
return hooks.CallbackHook(before_step=qat_before_step_callback)
return QATHook(cfg, self.build_detection_train_loader)
class GeneralizedRCNNRunner(Detectron2GoRunner):
@@ -335,6 +335,7 @@ class TestDefaultRunner(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
runner, cfg = setup(tmp_dir, backend=backend)
model = runner.build_model(cfg)
print(model)
runner.do_train(cfg, model, resume=True)
default_runner._close_all_tbx_writers()