Commit c2256758 authored by Zhicheng Yan, committed by Facebook GitHub Bot

Allow ignoring state dict keys in the QAT model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/642

When we build a QAT model using the FX graph mode APIs **prepare_qat_fx** and **convert_fx**, they run symbolic tracing over **module.forward()**.
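
For context, a minimal sketch of that flow (not code from this diff; the model, qconfig mapping, and example inputs are placeholders, and the exact API shape varies across PyTorch versions):

```python
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

model = MyModel().train()  # hypothetical fp32 module
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
example_inputs = (torch.randn(1, 3, 224, 224),)

# prepare_qat_fx symbolically traces model.forward() and inserts fake-quant modules
prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)
# ... QAT training loop on `prepared` ...
quantized = convert_fx(prepared.eval())
```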

In certain cases, such as when a module takes a constant tensor input, symbolic tracing adds new tensor attributes with the name prefix **_tensor_constant** (https://fburl.com/code/msc4ch4o), which become new keys in the QAT model state dict.
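
A minimal repro sketch (not from this diff): tracing a module whose **forward()** closes over a plain tensor, i.e. one that is neither a parameter nor a registered buffer, makes FX stash it on the traced module under a **_tensor_constant** name:

```python
import torch
from torch import nn

_CONSTANT = torch.ones(3)  # plain tensor, not a parameter or buffer

class AddConstant(nn.Module):
    def forward(self, x):
        return x + _CONSTANT  # FX records the constant as a module attribute

traced = torch.fx.symbolic_trace(AddConstant())
# prints something like ['_tensor_constant0']: a key the fp32 model never had
print(list(traced.state_dict().keys()))
```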

The current implementation of **_setup_non_qat_to_qat_state_dict_map** asserts that the state dicts of the original model and the QAT model have the same number of keys.

Thus, we extend **_setup_non_qat_to_qat_state_dict_map** with a new **qat_state_dict_keys_to_ignore** argument, which allows specified state dict keys in the QAT model to be ignored.
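
Usage sketch (the model class is hypothetical; the attribute name is the one read via getattr in the diff below): a model declares the QAT-only keys it wants skipped, and the key-count assertion is then checked against the remaining keys.

```python
from torch import nn

class MyQATFriendlyModel(nn.Module):  # hypothetical model
    # consumed by setup_qat_model via getattr(); these keys are removed
    # from the QAT state dict keys before the key-count assertion
    qat_model_state_dict_keys_to_ignore = ("_tensor_constant0",)

    def forward(self, x):
        return x
```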

Reviewed By: wat3rBro

Differential Revision: D52152706

fbshipit-source-id: 92219feae43bf8841b0a3a71adfbfcb84d8e8f95
parent 8f130231
@@ -21,6 +21,7 @@ from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
+from torch import nn
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
@@ -511,8 +512,14 @@ def setup_qat_model(
     # qat state dict mapper
     if not getattr(model, "_non_qat_to_qat_state_dict_map", None):
+        qat_state_dict_keys_to_ignore = getattr(
+            model, "qat_model_state_dict_keys_to_ignore", ()
+        )
         model = _setup_non_qat_to_qat_state_dict_map(
-            model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
+            model_fp32_state_dict,
+            model,
+            cfg.QUANTIZATION.EAGER_MODE,
+            qat_state_dict_keys_to_ignore,
         )
 
     # qat optimizer group for learnable qat
# qat optimizer group for learnable qat
@@ -522,8 +529,18 @@
 def _setup_non_qat_to_qat_state_dict_map(
-    model_fp32_state_dict, model_qat, is_eager_mode
+    model_fp32_state_dict: Dict,
+    model_qat: nn.Module,
+    is_eager_mode: bool,
+    qat_state_dict_keys_to_ignore: Tuple[str] = (),  # pyre-ignore
 ):
+    """
+    Args:
+        model_fp32_state_dict: state dict of the original fp32 pytorch model
+        model_qat: prepared qat model
+        is_eager_mode: whether the model is in eager mode
+        qat_state_dict_keys_to_ignore: a QAT model obtained by fuse_model_qat_fx()
+            (https://fburl.com/code/t70qv2aq) may contain new state dict keys that are
+            not present in the original model state dict (https://fburl.com/code/f8vk47w3).
+            Such keys need to be ignored when building the mapping from original model
+            state dict keys to new QAT model state dict keys below.
+    """
     original_state_dict_shapes = {k: v.shape for k, v in model_fp32_state_dict.items()}
     # fuse_model and prepare_qat may change the state_dict of model, keep a map from the
     # original model to the key QAT in order to load weight from non-QAT model.
@@ -531,7 +548,15 @@ def _setup_non_qat_to_qat_state_dict_map(
     new_state_dict_non_observer_keys = [
         k for k in new_state_dict_shapes if not _is_observer_key(k)
     ]
-    assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
+    new_state_dict_non_observer_keys_not_ignored = list(
+        set(new_state_dict_non_observer_keys).difference(
+            set(qat_state_dict_keys_to_ignore)
+        )
+    )
+    assert len(new_state_dict_non_observer_keys_not_ignored) == len(
+        original_state_dict_shapes
+    ), f"keys in state dict of original and new qat model {len(new_state_dict_non_observer_keys_not_ignored)} vs {len(original_state_dict_shapes)}"
     if is_eager_mode:
         for n_k, o_k in zip(