Commit 13b2fe71 authored by Matthew Yu, committed by Facebook GitHub Bot

swap the order of qat and layer freezing to preserve checkpoint values

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/399

Freezing the model before running quantization causes an issue when loading a saved checkpoint, because fusing does not support FrozenBatchNorm2d (which means the checkpoint could have a fused weight conv.bn.weight while the model would have an unfused weight bn.weight; see the sketch after the list below). The longer-term solution is to add FrozenBatchNorm2d to the fusing support, but there are some subtle issues there that will take some time to fix:
* need to move FrozenBatchNorm2d out of D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) and into the mobile_cv lib
* the current fuser has options to add new BN ops (e.g., FrozenBatchNorm2d), which we use with ops like SyncBN, but this is currently only tested with inference, so we need to write some additional checks for training
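
To make the mismatch concrete, here is a minimal PyTorch-only sketch (the `Block` module and its `conv`/`bn` names are made up for illustration; this is not d2go code). Eval-mode fusion folds the BN parameters into the conv, so the `bn.*` keys vanish from the fused state_dict and a checkpoint saved from a fused model no longer loads into a model whose BN stayed unfused (e.g., because it was swapped to FrozenBatchNorm2d, which the fuser cannot handle):

```python
# Minimal sketch of the state_dict key mismatch (hypothetical module, not d2go code).
import torch.nn as nn
from torch.ao.quantization import fuse_modules


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, bias=False)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


unfused = Block()
# Eval-mode fusion folds bn into conv and replaces bn with nn.Identity.
fused = fuse_modules(Block().eval(), [["conv", "bn"]])

print(sorted(unfused.state_dict()))  # ['bn.bias', 'bn.num_batches_tracked', ..., 'conv.weight']
print(sorted(fused.state_dict()))    # ['conv.bias', 'conv.weight']

# Loading the fused checkpoint into the unfused model fails on the missing bn.* keys.
try:
    unfused.load_state_dict(fused.state_dict())
except RuntimeError as e:
    print(e)
```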

The swap makes freezing compatible with QAT and should still work with standard models. One subtle potential issue is that the current BN swap assumes the BN is a leaf node. If a user runs QAT without fusing BN, the BN is no longer a leaf node, since it gains an activation_post_process child module to record its output; as a result, the BN will not be frozen in that specific case (see the sketch below). This should not normally occur, since BN is usually fused. A small adjustment to the BN swap would be to swap the BN regardless of whether it is a leaf node (though we would have to check whether the activation_post_process module is retained). Another long-term consideration is moving both freezing and quantization into modeling hooks so the user can decide the order.
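
For reference, a minimal eager-mode sketch of the leaf-node subtlety (assumes plain `prepare_qat` on a toy `nn.Sequential` with no BN fusion; `freeze_matched_bn` itself is d2go code and is not reproduced here, only the leaf check it relies on):

```python
# Sketch: after prepare_qat, an unfused BN gains an activation_post_process child,
# so a "leaf node" check (module has no children) used when swapping BN would skip it.
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.qconfig = get_default_qat_qconfig("fbgemm")
model.train()
prepare_qat(model, inplace=True)

bn = model[1]
print(type(bn).__name__)                          # still BatchNorm2d (nothing was fused)
print([name for name, _ in bn.named_children()])  # ['activation_post_process']
print(len(list(bn.children())) == 0)              # False -> a leaf-only BN swap misses it
```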

Reviewed By: wat3rBro

Differential Revision: D40496052

fbshipit-source-id: 0d7e467b833821f7952cd2fce459ae1f76e1fa3b
parent 2e52e963
@@ -216,10 +216,6 @@ class Detectron2GoRunner(BaseRunner):
         model = build_d2go_model(cfg).model
         model_ema.may_build_model_ema(cfg, model)
-        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
-            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
-            model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)
         if cfg.QUANTIZATION.QAT.ENABLED:
             # Disable fake_quant and observer so that the model will be trained normally
             # before QAT being turned on (controlled by QUANTIZATION.QAT.START_ITER).
@@ -243,6 +239,10 @@ class Detectron2GoRunner(BaseRunner):
                 enable_observer=False,
             )
+        if cfg.MODEL.FROZEN_LAYER_REG_EXP:
+            set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)
+            model = freeze_matched_bn(model, cfg.MODEL.FROZEN_LAYER_REG_EXP)
         if eval_only:
             checkpointer = self.build_checkpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
             checkpointer.load(cfg.MODEL.WEIGHTS)