"examples/git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "ea1a410d590e63e6fd24942ab8376600c12e2194"
Commit c2256758 authored by Zhicheng Yan, committed by Facebook GitHub Bot

Allow ignoring state dict keys in QAT model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/642

When we build a QAT model using the FX graph mode APIs **prepare_qat_fx** and **convert_fx**, they run symbolic tracing over **module.forward()**.
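For context, a minimal sketch of that flow (not part of this diff), using PyTorch's FX graph mode quantization API; exact module paths and signatures vary across PyTorch versions:

```python
import torch
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

# Toy model standing in for the real network.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).train()
example_inputs = (torch.randn(1, 3, 32, 32),)

# prepare_qat_fx symbolically traces model.forward() and inserts fake-quant modules.
prepared = prepare_qat_fx(
    model, get_default_qat_qconfig_mapping("fbgemm"), example_inputs
)

# ... QAT fine-tuning of `prepared` would happen here ...

# convert_fx lowers the trained graph to a quantized model.
quantized = convert_fx(prepared.eval())
```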

In certain cases, such as when a module takes a constant tensor input, symbolic tracing adds new tensor attributes with the name prefix **_tensor_constant** (https://fburl.com/code/msc4ch4o), which become new keys in the QAT model state dict.
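For illustration, a standalone sketch (not part of this diff) showing how FX tracing lifts an inline constant into a **_tensor_constant0** attribute that then appears in the traced module's state dict:

```python
import torch
from torch import nn

class AddConst(nn.Module):
    def forward(self, x):
        # A concrete tensor created inside forward(); symbolic tracing cannot
        # trace through its construction, so the tracer attaches the value to
        # the module under a generated "_tensor_constant*" name.
        return x + torch.tensor([1.0, 2.0, 3.0])

traced = torch.fx.symbolic_trace(AddConst())
# Expected to include a lifted constant key such as '_tensor_constant0'.
print(list(traced.state_dict().keys()))
```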

The current implementation of **_setup_non_qat_to_qat_state_dict_map** asserts that the number of keys in the state dicts of the original and QAT models is the same.

Thus, we extend **_setup_non_qat_to_qat_state_dict_map** with a new **qat_state_dict_keys_to_ignore** argument, which allows ignoring specified state dict keys in the QAT model.
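A hypothetical usage sketch: the attribute name below matches the getattr() call in the diff, while the model and the ignored key are placeholders:

```python
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
# Suppose FX tracing during QAT preparation introduces a '_tensor_constant0'
# key. setup_qat_model() reads this attribute and forwards it to
# _setup_non_qat_to_qat_state_dict_map(), which then excludes the key from
# the original-vs-QAT key-count assertion.
model.qat_model_state_dict_keys_to_ignore = ("_tensor_constant0",)
```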

Reviewed By: wat3rBro

Differential Revision: D52152706

fbshipit-source-id: 92219feae43bf8841b0a3a71adfbfcb84d8e8f95
parent 8f130231
```diff
@@ -21,6 +21,7 @@ from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
+from torch import nn
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
@@ -511,8 +512,14 @@ def setup_qat_model(
     # qat state dict mapper
     if not getattr(model, "_non_qat_to_qat_state_dict_map", None):
+        qat_state_dict_keys_to_ignore = getattr(
+            model, "qat_model_state_dict_keys_to_ignore", ()
+        )
         model = _setup_non_qat_to_qat_state_dict_map(
-            model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
+            model_fp32_state_dict,
+            model,
+            cfg.QUANTIZATION.EAGER_MODE,
+            qat_state_dict_keys_to_ignore,
         )
     # qat optimizer group for learnable qat
@@ -522,8 +529,18 @@ def setup_qat_model(
 def _setup_non_qat_to_qat_state_dict_map(
-    model_fp32_state_dict, model_qat, is_eager_mode
+    model_fp32_state_dict: Dict,
+    model_qat: nn.Module,
+    is_eager_mode: bool,
+    qat_state_dict_keys_to_ignore: Tuple[str] = (),  # pyre-ignore
 ):
+    """
+    Args:
+        model_fp32_state_dict: state dict of the original fp32 pytorch model
+        model_qat: prepared QAT model
+        is_eager_mode: whether the model is in eager mode
+        qat_state_dict_keys_to_ignore: a QAT model obtained via fuse_model_qat_fx()
+            (https://fburl.com/code/t70qv2aq) may contain new state dict keys that are
+            not present in the original model state dict (https://fburl.com/code/f8vk47w3).
+            Such keys need to be ignored when building the mapping from original model
+            state dict keys to new QAT model state dict keys below.
+    """
     original_state_dict_shapes = {k: v.shape for k, v in model_fp32_state_dict.items()}
 
     # fuse_model and prepare_qat may change the state_dict of the model; keep a map from
     # the original model key to the QAT key in order to load weights from the non-QAT model.
@@ -531,7 +548,15 @@ def _setup_non_qat_to_qat_state_dict_map(
     new_state_dict_non_observer_keys = [
         k for k in new_state_dict_shapes if not _is_observer_key(k)
     ]
-    assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
+    new_state_dict_non_observer_keys_not_ignored = list(
+        set(new_state_dict_non_observer_keys).difference(
+            set(qat_state_dict_keys_to_ignore)
+        )
+    )
+    assert len(new_state_dict_non_observer_keys_not_ignored) == len(
+        original_state_dict_shapes
+    ), f"keys in state dict of original and new qat model {len(new_state_dict_non_observer_keys_not_ignored)} vs {len(original_state_dict_shapes)}"
     if is_eager_mode:
         for n_k, o_k in zip(
...
```