Commit ef9c20cc authored by Peizhao Zhang, committed by Facebook GitHub Bot

Refactored QAT-related code.

Summary:
Refactored QAT-related code.
* Moved the model-preparation code into a new `_prepare_model_for_qat` function.
* Moved the state-dict-mapping code into a new `_setup_non_qat_to_qat_state_dict_map` function.
* Moved the QAT hook logic into the quantization file and reimplemented it as a `QATHook` class.
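A minimal sketch of the resulting call structure (signatures taken from the diff below; bodies abbreviated, with the device move and fake-quant/observer toggling elided):

# Abbreviated sketch; see the full diff below for the real bodies.
@mock_quantization_type
def setup_qat_model(cfg, model_fp32, enable_fake_quant=False, enable_observer=False):
    fp32_state_dict = model_fp32.state_dict()        # captured before fusion/prepare
    model = _prepare_model_for_qat(cfg, model_fp32)  # extracted eager/FX prepare step
    # ... move observers to model_fp32.device, toggle fake-quant and observers ...
    return _setup_non_qat_to_qat_state_dict_map(
        fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
    )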

Differential Revision: D31370819

fbshipit-source-id: 836550b2c8d68cd93a84d5877ad9cef6f0f0eb39
parent bfc08c53
@@ -2,14 +2,18 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import contextlib
 import copy
 import inspect
 import logging
+import math
 from typing import Tuple

+import detectron2.utils.comm as comm
 import torch
 from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.engine import HookBase
+from detectron2.engine import SimpleTrainer
+from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
@@ -185,7 +189,7 @@ def mock_quantization_type(quant_func):
     @functools.wraps(quant_func)
     def wrapper(cfg, model, *args, **kwargs):
-        if type(d2l.Linear) == torch.nn.Linear:
+        if d2l.Linear == torch.nn.Linear:
            # we do not need the mock anymore when the type is as expected; consider
            # removing the related code
            logger.warning(
@@ -317,15 +321,7 @@ def post_training_quantize(cfg, model, data_loader):
     return model


-@mock_quantization_type
-def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
-    if hasattr(model, "_non_qat_to_qat_state_dict_map"):
-        raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
-
-    device = model.device
-    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
-    original_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
+def _prepare_model_for_qat(cfg, model):
     if cfg.QUANTIZATION.EAGER_MODE:
         if hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg)
@@ -341,6 +337,27 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
     else:
         logger.info("Using default implementation for prepare_for_quant")
         model = default_prepare_for_quant(cfg, model)
+    return model
+
+
+@mock_quantization_type
+def setup_qat_model(
+    cfg,
+    model_fp32,
+    enable_fake_quant: bool = False,
+    enable_observer: bool = False,
+):
+    if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"):
+        raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
+
+    device = model_fp32.device
+    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
+
+    # _prepare_model_for_qat may modify the fp32 model in place, so capture the
+    # fp32 state dict (keys and shapes) first
+    model_fp32_state_dict = model_fp32.state_dict()
+
+    # prepare the model for QAT
+    model = _prepare_model_for_qat(cfg, model_fp32)
+
     # Move newly added observers to the original device
     model.to(device)
@@ -352,21 +369,33 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
         logger.info("Disabling observer ...")
         model.apply(torch.ao.quantization.disable_observer)

+    # set up the non-QAT to QAT state dict map
+    model = _setup_non_qat_to_qat_state_dict_map(
+        model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
+    )
+
+    return model
+
+
+def _setup_non_qat_to_qat_state_dict_map(
+    model_fp32_state_dict, model_qat, is_eager_mode
+):
+    original_state_dict_shapes = {k: v.shape for k, v in model_fp32_state_dict.items()}
+
     # fuse_model and prepare_qat may change the state_dict of the model; keep a map
     # from the original keys to the QAT keys in order to load weights from a
     # non-QAT model.
-    new_state_dict_shapes = {k: v.shape for k, v in model.state_dict().items()}
+    new_state_dict_shapes = {k: v.shape for k, v in model_qat.state_dict().items()}
     new_state_dict_non_observer_keys = [
         k for k in new_state_dict_shapes if not _is_observer_key(k)
     ]
     assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
-    if cfg.QUANTIZATION.EAGER_MODE:
+    if is_eager_mode:
         for n_k, o_k in zip(
             new_state_dict_non_observer_keys, original_state_dict_shapes
         ):
             assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
         # _non_qat_to_qat_state_dict_map maps original (non-QAT) keys to QAT keys
-        model._non_qat_to_qat_state_dict_map = dict(
+        model_qat._non_qat_to_qat_state_dict_map = dict(
             zip(original_state_dict_shapes, new_state_dict_non_observer_keys)
         )
     else:
@@ -386,13 +415,109 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
             # - bn
             return old_bn_key.replace(".bn.", ".conv.bn.")

-        model._non_qat_to_qat_state_dict_map = {}
+        model_qat._non_qat_to_qat_state_dict_map = {}
         for key in original_state_dict_shapes.keys():
             if key in new_state_dict_non_observer_keys:
-                model._non_qat_to_qat_state_dict_map[key] = key
+                model_qat._non_qat_to_qat_state_dict_map[key] = key
             else:
                 maybe_new_bn_key = get_new_bn_key(key)
                 if maybe_new_bn_key in new_state_dict_non_observer_keys:
-                    model._non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
-
-    return model
+                    model_qat._non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
+
+    return model_qat
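To make the BN-key remapping above concrete, here is a small self-contained sketch; the key names are hypothetical, not taken from any real model:

# Hypothetical fp32 vs. QAT state dict keys after conv/bn fusion.
fp32_keys = ["stem.conv.weight", "stem.bn.weight", "stem.bn.bias"]
qat_keys = ["stem.conv.weight", "stem.conv.bn.weight", "stem.conv.bn.bias"]

def get_new_bn_key(old_bn_key):
    # Fusion nests bn under conv, so "x.bn.y" becomes "x.conv.bn.y".
    return old_bn_key.replace(".bn.", ".conv.bn.")

mapping = {}
for key in fp32_keys:
    if key in qat_keys:
        mapping[key] = key
    elif get_new_bn_key(key) in qat_keys:
        mapping[key] = get_new_bn_key(key)

assert mapping["stem.bn.weight"] == "stem.conv.bn.weight"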
+
+
+class QATHook(HookBase):
+    def __init__(self, cfg, build_data_loader_func=None):
+        self.cfg = cfg
+        self.build_data_loader_func = build_data_loader_func
+        self._applied = {
+            "enable_fake_quant": False,
+            "enable_observer": False,
+            "disable_observer": False,
+            "freeze_bn_stats": False,
+        }
+
+        assert (
+            cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
+            <= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
+        ), "Can't disable observer before enabling it"
+
+    def before_step(self):
+        cur_iter = self.trainer.iter
+        model = self.trainer.model
+        cfg = self.cfg
+
+        if (
+            not self._applied["enable_fake_quant"]
+            and cur_iter >= cfg.QUANTIZATION.QAT.START_ITER
+        ):
+            logger.info(
+                "[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
+            )
+            model.apply(torch.ao.quantization.enable_fake_quant)
+            self._applied["enable_fake_quant"] = True
+
+            _reset_qat_data_loader_if_needed(
+                self.cfg, self.trainer, self.build_data_loader_func
+            )
+
+        if (
+            not self._applied["enable_observer"]
+            and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
+            and cur_iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
+        ):
+            logger.info("[QAT] enable observer, iter = {}".format(cur_iter))
+            model.apply(torch.ao.quantization.enable_observer)
+            self._applied["enable_observer"] = True
+
+        if (
+            not self._applied["disable_observer"]
+            and cur_iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
+        ):
+            logger.info(
+                "[QAT] disabling observer for subsequent iters, iter = {}".format(cur_iter)
+            )
+            model.apply(torch.ao.quantization.disable_observer)
+            self._applied["disable_observer"] = True
+
+        if (
+            not self._applied["freeze_bn_stats"]
+            and cur_iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
+        ):
+            logger.info(
+                "[QAT] freezing BN for subsequent iters, iter = {}".format(cur_iter)
+            )
+            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+            self._applied["freeze_bn_stats"] = True
+
+        if (
+            self._applied["enable_fake_quant"]
+            and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
+            and cur_iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD == 0
+        ):
+            logger.info(f"[QAT] updating observers, iter = {cur_iter}")
+            model.apply(observer_update_stat)
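For reference, a standalone sketch of the torch.ao.quantization toggles that before_step applies; the toy model and iteration thresholds here are assumptions for illustration, not values from this diff:

import torch
import torch.ao.quantization as tq

toy = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
toy.qconfig = tq.get_default_qat_qconfig("qnnpack")
toy_qat = tq.prepare_qat(toy.train())  # inserts observers and fake-quant modules

START_ITER, DISABLE_OBSERVER_ITER = 0, 200  # hypothetical schedule
for cur_iter in range(300):
    if cur_iter == START_ITER:
        toy_qat.apply(tq.enable_fake_quant)  # start simulating quantized inference
    if cur_iter == DISABLE_OBSERVER_ITER:
        toy_qat.apply(tq.disable_observer)   # freeze the observed ranges
    # ... one training step on toy_qat would go here ...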
+
+
+def _reset_qat_data_loader_if_needed(cfg, trainer, build_loader_func):
+    if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
+        loader_cfg = cfg.clone()
+        loader_cfg.defrost()
+        num_gpus = comm.get_world_size()
+        old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
+        new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
+        loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
+        loader_cfg.freeze()
+        logger.info(
+            "[QAT] Rebuild data loader with batch size per GPU: {} -> {}".format(
+                old_bs, new_bs
+            )
+        )
+        # This method assumes the data loader can be replaced on the trainer
+        assert trainer.__class__ == SimpleTrainer
+        del trainer._data_loader_iter
+        del trainer.data_loader
+        data_loader = build_loader_func(loader_cfg)
+        trainer.data_loader = data_loader
+        trainer._data_loader_iter = iter(data_loader)
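A worked instance of the batch-size arithmetic above, with hypothetical numbers:

import math

IMS_PER_BATCH, num_gpus, factor = 64, 8, 0.5  # assumed config values
old_bs = IMS_PER_BATCH // num_gpus            # 8 images per GPU before QAT
new_bs = math.ceil(old_bs * factor)           # 4 images per GPU once fake-quant starts
assert new_bs * num_gpus == 32                # the rebuilt loader's IMS_PER_BATCH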
@@ -3,7 +3,6 @@
 import logging
-import math
 import os
 from collections import OrderedDict
 from functools import lru_cache
@@ -35,6 +34,7 @@ from d2go.modeling.model_freezing_utils import (
 from d2go.modeling.quantization import (
     QATCheckpointer,
     setup_qat_model,
+    QATHook,
 )
 from d2go.optimizer import build_optimizer_mapper
 from d2go.utils.flop_calculator import add_flop_printing_hook
@@ -65,7 +65,6 @@ from detectron2.solver import (
 )
 from detectron2.utils.events import CommonMetricPrinter, JSONWriter
 from detectron2.utils.registry import Registry
-from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.predictor.api import PredictorWrapper
@@ -593,98 +592,14 @@ class Detectron2GoRunner(BaseRunner):
         """
         return None

-    def _create_qat_hook(self, cfg):
+    def _create_qat_hook(self, cfg) -> Optional[QATHook]:
         """
         Create a hook to start QAT (during training) and/or change the phase of QAT.
         """
-        applied = {
-            "enable_fake_quant": False,
-            "enable_observer": False,
-            "disable_observer": False,
-            "freeze_bn_stats": False,
-        }
-
-        assert (
-            cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
-            <= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
-        ), "Can't diable observer before enabling it"
-
-        def qat_before_step_callback(trainer):
-            if (
-                not applied["enable_fake_quant"]
-                and trainer.iter >= cfg.QUANTIZATION.QAT.START_ITER
-            ):
-                logger.info(
-                    "[QAT] enable fake quant to start QAT, iter = {}".format(
-                        trainer.iter
-                    )
-                )
-                trainer.model.apply(torch.ao.quantization.enable_fake_quant)
-                applied["enable_fake_quant"] = True
-
-                if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
-                    loader_cfg = cfg.clone()
-                    loader_cfg.defrost()
-                    num_gpus = comm.get_world_size()
-                    old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
-                    new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
-                    loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
-                    loader_cfg.freeze()
-                    logger.info(
-                        "[QAT] Rebuild data loader with batch size per GPU: {} -> {}".format(
-                            old_bs, new_bs
-                        )
-                    )
-                    # This method assumes the data loader can be replaced from trainer
-                    assert trainer.__class__ == SimpleTrainer
-                    del trainer._data_loader_iter
-                    del trainer.data_loader
-                    data_loader = self.build_detection_train_loader(loader_cfg)
-                    trainer.data_loader = data_loader
-                    trainer._data_loader_iter = iter(data_loader)
-
-            if (
-                not applied["enable_observer"]
-                and trainer.iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
-                and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
-            ):
-                logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
-                trainer.model.apply(torch.ao.quantization.enable_observer)
-                applied["enable_observer"] = True
-
-            if (
-                not applied["disable_observer"]
-                and trainer.iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
-            ):
-                logger.info(
-                    "[QAT] disabling observer for sub seq iters, iter = {}".format(
-                        trainer.iter
-                    )
-                )
-                trainer.model.apply(torch.ao.quantization.disable_observer)
-                applied["disable_observer"] = True
-
-            if (
-                not applied["freeze_bn_stats"]
-                and trainer.iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
-            ):
-                logger.info(
-                    "[QAT] freezing BN for subseq iters, iter = {}".format(trainer.iter)
-                )
-                trainer.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
-                applied["freeze_bn_stats"] = True
-
-            if (
-                applied["enable_fake_quant"]
-                and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
-                and trainer.iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD
-                == 0
-            ):
-                logger.info(f"[QAT] updating observers, iter = {trainer.iter}")
-                trainer.model.apply(observer_update_stat)
-
-        return hooks.CallbackHook(before_step=qat_before_step_callback)
+        if not cfg.QUANTIZATION.QAT.ENABLED:
+            return None
+
+        return QATHook(cfg, self.build_detection_train_loader)
 class GeneralizedRCNNRunner(Detectron2GoRunner):
...
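The hook returned by _create_qat_hook is registered on the trainer elsewhere in the runner; a hypothetical wiring sketch (the names runner and trainer are assumptions, not from this diff):

from detectron2.engine import SimpleTrainer

def register_qat_hook(runner, trainer: SimpleTrainer, cfg) -> None:
    # _create_qat_hook returns None when cfg.QUANTIZATION.QAT.ENABLED is False.
    qat_hook = runner._create_qat_hook(cfg)
    if qat_hook is not None:
        trainer.register_hooks([qat_hook])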
@@ -335,6 +335,7 @@ class TestDefaultRunner(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir:
             runner, cfg = setup(tmp_dir, backend=backend)
             model = runner.build_model(cfg)
+            print(model)
             runner.do_train(cfg, model, resume=True)
             default_runner._close_all_tbx_writers()
...