Commit 8f130231 authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Facebook GitHub Bot
Browse files

do not fuse model again for a QAT model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/643

For a QAT model, it contains observers. After QAT training, those observers already contain updated statistics, such as min_val, max_val.

When we want to export FP32 QAT model for a sanity check, if we call **fuse_utils.fuse_model()** again (which is often already called when we build the QAT model before QAT training), it will remove statistics in the observers.

Reviewed By: wat3rBro

Differential Revision: D52152688

fbshipit-source-id: 08aa16f2aa72b3809e0ba2d346f1b806c0e6ede7
parent da53aa10
...@@ -86,6 +86,8 @@ def _convert_fp_model( ...@@ -86,6 +86,8 @@ def _convert_fp_model(
cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
) -> nn.Module: ) -> nn.Module:
"""Converts floating point predictor""" """Converts floating point predictor"""
if not isinstance(cfg, CfgNode) or (not cfg.QUANTIZATION.QAT.ENABLED):
# Do not fuse model again for QAT model since it will remove observer statistics (e.g. min_val, max_val)
pytorch_model = fuse_utils.fuse_model(pytorch_model) pytorch_model = fuse_utils.fuse_model(pytorch_model)
logger.info(f"Fused Model:\n{pytorch_model}") logger.info(f"Fused Model:\n{pytorch_model}")
if fuse_utils.count_bn_exist(pytorch_model) > 0: if fuse_utils.count_bn_exist(pytorch_model) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment