Commit 3c6f71b4 authored by Zhicheng Yan, committed by Facebook GitHub Bot

expose example_input argument in setup_qat_model()

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/647

Major changes
- The **example_input** argument of **prepare_fake_quant_model()** is useful in certain cases. For example, the Argos model's **custom_prepare_fx()** method under the FX graph + QAT setup (D52760682) uses it to prepare example inputs for individual sub-modules by running one forward pass and recording the inputs seen by each sub-module (see the sketch after this list). Therefore, we expose the **example_input** argument in the **setup_qat_model()** function.
- For the QAT model, we currently assert that the number of state dict keys (excluding observers) equals the number of state dict keys in the original model. However, when the assertion fails, it does not log any useful information for debugging. We change it to report which keys are unique to each state dict.
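
As a rough illustration of the bookkeeping described in the first point, the sketch below captures the inputs of individual sub-modules with forward pre-hooks during a single pass over `example_input`. This is not the Argos `custom_prepare_fx()` implementation; the helper name, the toy model, and the hook-based approach are assumptions for illustration only.

```python
# Hypothetical sketch: record the example inputs seen by each sub-module
# during one forward pass. Not the actual d2go / Argos implementation.
import torch
import torch.nn as nn


def collect_submodule_inputs(model: nn.Module, example_input: torch.Tensor):
    """Run one forward pass and bookkeep the positional inputs of every sub-module."""
    recorded_inputs = {}
    handles = []

    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        def _hook(mod, inputs, _name=name):
            recorded_inputs[_name] = inputs  # tuple of positional args

        handles.append(module.register_forward_pre_hook(_hook))

    with torch.no_grad():
        model(example_input)

    for handle in handles:
        handle.remove()
    return recorded_inputs


# Usage with a toy model: maps each sub-module name to its example inputs.
toy_model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
sub_inputs = collect_submodule_inputs(toy_model, torch.randn(1, 8))
```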

Reviewed By: navsud

Differential Revision: D52760688

fbshipit-source-id: 27535a0324ebe6513f198acb839918a0346720d0
parent 92af450b
@@ -5,7 +5,7 @@
 import copy
 import logging
 import math
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

 import detectron2.utils.comm as comm
 import torch
@@ -471,6 +471,7 @@ def setup_qat_model(
     enable_fake_quant: bool = False,
     enable_observer: bool = False,
     enable_learnable_observer: bool = False,
+    example_input: Optional[Any] = None,
 ):
     assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in [
         "default",
@@ -490,7 +491,7 @@
     model_fp32_state_dict = model_fp32.state_dict()
     # prepare model for qat
-    model = prepare_fake_quant_model(cfg, model_fp32, True)
+    model = prepare_fake_quant_model(cfg, model_fp32, True, example_input=example_input)
     # make sure the proper qconfig are used in the model
     learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
@@ -554,9 +555,23 @@ def _setup_non_qat_to_qat_state_dict_map(
         )
     )
-    assert len(new_state_dict_non_observer_keys_not_ignored) == len(
+    if not len(new_state_dict_non_observer_keys_not_ignored) == len(
         original_state_dict_shapes
-    ), f"keys in state dict of original and new qat model {len(new_state_dict_non_observer_keys_not_ignored)} vs {len(original_state_dict_shapes)}"
+    ):
+        a = set(new_state_dict_non_observer_keys_not_ignored)
+        b = set(original_state_dict_shapes.keys())
+        a_diff_b = a.difference(b)
+        b_diff_a = b.difference(a)
+        logger.info("unique keys in qat model state dict")
+        for key in a_diff_b:
+            logger.info(f"{key}")
+        logger.info("unique keys in original model state dict")
+        for key in b_diff_a:
+            logger.info(f"{key}")
+        raise RuntimeError(
+            f"an inconsistent number of keys in state dict of new qat and original model: {len(a)} vs {len(b)}"
+        )
     if is_eager_mode:
         for n_k, o_k in zip(