Commit 5074692f authored by Tsahi Glik, committed by Facebook GitHub Bot

Cleanup QAT api

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/190

Currently the export code is fragmented in how convert logic is applied in the various quantization modes: `prepare_for_quant_convert` is only called in non-eager modes, and the convert logic in eager mode is not customizable.
This diff unifies the `prepare_for_quant_convert` code path across all quantization modes.
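
As a rough sketch (hypothetical model class, not code from this diff), a model can now own its convert step by defining `prepare_for_quant_convert`, which `convert_predictor` calls regardless of `QUANTIZATION.EAGER_MODE`:

```python
# Hypothetical sketch: illustrates the unified hook, not code from this diff.
import torch
from torch.ao.quantization import convert
from torch.ao.quantization.quantize_fx import convert_fx


class MyMetaArch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    def forward(self, x):
        return self.backbone(x)

    def prepare_for_quant_convert(self, cfg):
        # convert_predictor calls this hook whether or not QUANTIZATION.EAGER_MODE is set.
        if cfg.QUANTIZATION.EAGER_MODE:
            return convert(self, inplace=False)
        # Assumes self.backbone was prepared with prepare_qat_fx / prepare_fx earlier,
        # so it is a GraphModule that convert_fx accepts.
        self.backbone = convert_fx(self.backbone)
        return self
```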
This diff also renames `_non_qat_to_qat_state_dict_map`, which is used by the QAT checkpointer, to the public variable `non_qat_to_qat_state_dict_map`, and allows models to populate it with a custom mapping. This is useful when the parameter mapping between the non-QAT model and the QAT model cannot be inferred definitively (see the note in https://fburl.com/code/9rx172ht) and has ambiguity that only the model's own logic can resolve.
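
As a minimal sketch, with illustrative parameter names and using the attribute name as it appears in the hunks below, a model could supply its own mapping so that `setup_qat_model` skips the inferred one:

```python
# Hypothetical sketch: the keys/values are illustrative, not real d2go names.
# setup_qat_model only infers the FP32 -> QAT key mapping when the model has
# not already populated this attribute (see the getattr check in the diff below).
model._non_qat_to_qat_state_dict_map = {
    "backbone.stem.conv1.weight": "backbone.stem.conv1.module.weight",
    "backbone.stem.conv1.bias": "backbone.stem.conv1.module.bias",
}
```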

Reviewed By: wat3rBro

Differential Revision: D34741217

fbshipit-source-id: 38edfec64200ec986ffe4f3d47f527cb6a3fb5e9
parent 2b618211
@@ -114,14 +114,13 @@ def convert_predictor(
             logger.warn("Post training quantized model has bn inside fused ops")
         logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
-        if cfg.QUANTIZATION.EAGER_MODE:
-            # TODO(T93870278): move this logic to prepare_for_quant_convert
-            pytorch_model = convert(pytorch_model, inplace=False)
-        else:  # FX graph mode quantization
         if hasattr(pytorch_model, "prepare_for_quant_convert"):
             pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
         else:
             # TODO(T93870381): move this to a default function
+            if cfg.QUANTIZATION.EAGER_MODE:
+                pytorch_model = convert(pytorch_model, inplace=False)
+            else:  # FX graph mode quantization
                 pytorch_model = convert_fx(pytorch_model)
         logger.info("Quantized Model:\n{}".format(pytorch_model))
...
@@ -20,6 +20,7 @@ from mobile_cv.arch.utils.quantize_utils import (
     QuantWrapper,
 )
 from mobile_cv.predictor.api import FuncInfo
+from torch.ao.quantization import convert
 from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx

 logger = logging.getLogger(__name__)
@@ -278,8 +279,8 @@ def default_rcnn_prepare_for_quant(self, cfg):
 @RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.register()
 def default_rcnn_prepare_for_quant_convert(self, cfg):
     if cfg.QUANTIZATION.EAGER_MODE:
-        raise NotImplementedError()
+        convert(self, inplace=True)
+    else:
         assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
         self.backbone = convert_fx(
             self.backbone,
...
@@ -393,6 +393,7 @@ def setup_qat_model(
     model.apply(qat_utils.disable_lqat_learnable_observer)
     # qat state dict mapper
+    if not getattr(model, "_non_qat_to_qat_state_dict_map", None):
         model = _setup_non_qat_to_qat_state_dict_map(
             model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
         )
@@ -419,7 +420,9 @@ def _setup_non_qat_to_qat_state_dict_map(
     for n_k, o_k in zip(
         new_state_dict_non_observer_keys, original_state_dict_shapes
     ):
-        assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
+        assert (
+            new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
+        ), f"QAT model shapes is inconsistent. FP32.{o_k}={original_state_dict_shapes[o_k]} , QAT.{n_k}={new_state_dict_shapes[n_k]}"
     # _q_state_dict_map will store
     model_qat._non_qat_to_qat_state_dict_map = dict(
         zip(original_state_dict_shapes, new_state_dict_non_observer_keys)
...