Commit 62c21f25 authored by Yanghan Wang, committed by Facebook GitHub Bot

remove deprecated silicon quantization

Summary: EZ

Reviewed By: zhanghang1989

Differential Revision: D29000628

fbshipit-source-id: f954214dfe3a989fc145663f8bb1870812e78ce7
parent fc690b45
@@ -113,33 +113,12 @@ def add_quantization_default_configs(_C):
     _C.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES = 1
     _C.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU = False
-    # deprecated
-    _C.QUANTIZATION.SILICON_QAT = CfgNode()
-    _C.QUANTIZATION.SILICON_QAT.ENABLED = False
     # register deprecated and renamed keys
     _C.register_deprecated_key("QUANTIZATION.QAT.LOAD_PRETRAINED")
     _C.register_renamed_key("QUANTIZATION.QAT.BACKEND", "QUANTIZATION.BACKEND")
     _C.register_deprecated_key("QUANTIZATION.ENABLE_CUSTOM_QSCHEME")
-@contextlib.contextmanager
-def silicon_qat_build_model_context(cfg):
-    mock_ctx_managers = []
-    if cfg.QUANTIZATION.SILICON_QAT.ENABLED:
-        from mobile_cv.silicon_pytorch_qat.replace_op import mock_quant_ops
-        mock_ctx_managers.extend(
-            [
-                mock_quant_ops(quant_op="quant_add"),
-                mock_quant_ops(quant_op="quant_fbb_convbnrelu"),
-            ]
-        )
-    with contextlib.ExitStack() as stack:
-        for mgr in mock_ctx_managers:
-            stack.enter_context(mgr)
-        yield
+    _C.register_deprecated_key("QUANTIZATION.SILICON_QAT")
+    _C.register_deprecated_key("QUANTIZATION.SILICON_QAT.ENABLED")
 # TODO: model.to(device) might not work for detection meta-arch, this function is the
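The replacement above leans on the config's deprecated-key registry instead of keeping dummy default values around. A minimal sketch of how that registry behaves, using the plain yacs CfgNode API that D2Go's config ultimately builds on (the BACKEND default below is a placeholder, and the exact warning text depends on the yacs version):

    from yacs.config import CfgNode

    _C = CfgNode()
    _C.QUANTIZATION = CfgNode()
    _C.QUANTIZATION.BACKEND = "fbgemm"  # placeholder default for the sketch

    # Old config files may still carry the removed key; registering it keeps
    # them loadable instead of failing on an unknown key.
    _C.register_deprecated_key("QUANTIZATION.SILICON_QAT.ENABLED")

    # The deprecated key is skipped (with a warning) during merging.
    _C.merge_from_list(["QUANTIZATION.SILICON_QAT.ENABLED", "True"])
    assert "SILICON_QAT" not in _C.QUANTIZATION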
@@ -168,26 +147,29 @@ def _cast_detection_model(model, device):
 def add_d2_quant_mapping(mappings):
-    """ HACK: Add d2 specific module mapping for eager model quantization
-    """
+    """HACK: Add d2 specific module mapping for eager model quantization"""
     import torch.quantization.quantization_mappings as qm
     for k, v in mappings.items():
         if k not in qm.get_default_static_quant_module_mappings():
             qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[k] = v
         if k not in qm.get_default_qat_module_mappings():
             qm.DEFAULT_QAT_MODULE_MAPPINGS[k] = v
 # The `mock_quantization_type` decorate may not be needed anymore to unify
 # detectron2.layers modules and torch.nn modules since Pytorch 1.5. See comments on D23790034.
 def mock_quantization_type(quant_func):
-    import mock
     import builtins
     import functools
     import detectron2.layers as d2l
+    import mock
     type_mapping = {d2l.Linear: torch.nn.Linear}
     from d2go.utils.misc import check_version
-    if check_version(torch, '1.7.2', warning_only=True):
+    if check_version(torch, "1.7.2", warning_only=True):
         add_d2_quant_mapping(type_mapping)
     real_type = builtins.type
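For context on what add_d2_quant_mapping is patching: the qm tables map float module classes to the classes that eager-mode prepare()/convert() swap them to, and the helper simply inserts extra detectron2-specific keys into those same dicts. A small sketch of inspecting the tables, written against the pre-FX torch.quantization namespace the diff itself uses (roughly PyTorch 1.7-1.9; the exact contents vary by version):

    import torch
    import torch.quantization.quantization_mappings as qm

    static_map = qm.get_default_static_quant_module_mappings()
    qat_map = qm.get_default_qat_module_mappings()

    # nn.Linear is statically quantized to nn.quantized.Linear, and is trained
    # with fake-quant as nn.qat.Linear during QAT.
    print(static_map[torch.nn.Linear])
    print(qat_map[torch.nn.Linear])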
@@ -270,13 +252,14 @@ def default_prepare_for_quant(cfg, model):
     return model
 def default_prepare_for_quant_convert(cfg, model):
     return torch.quantization.quantize_fx.convert_fx(model)
 @mock_quantization_type
 def post_training_quantize(cfg, model, data_loader):
-    """ Calibrate a model, convert it to a quantized pytorch model """
+    """Calibrate a model, convert it to a quantized pytorch model"""
     model = copy.deepcopy(model)
     model.eval()
     # TODO: check why some parameters will have gradient
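post_training_quantize above calibrates a model and then converts it; D2Go's actual implementation (eager/FX handling, data-loader-driven calibration) lives in the surrounding file, but the underlying eager-mode recipe it wraps is the standard observe, calibrate, convert flow. A minimal generic sketch of that recipe (SmallNet and the random calibration data are stand-ins, not D2Go code):

    import copy

    import torch
    import torch.nn as nn
    import torch.quantization as tq


    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = tq.QuantStub()      # float tensors become quantized here
            self.fc = nn.Linear(8, 4)
            self.dequant = tq.DeQuantStub()  # back to float at the output

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))


    float_model = SmallNet()
    model = copy.deepcopy(float_model)  # keep the original intact, as the diff does
    model.eval()
    model.qconfig = tq.get_default_qconfig("fbgemm")  # "qnnpack" on ARM
    prepared = tq.prepare(model)

    # Calibration: run a few representative batches so observers record ranges.
    for _ in range(4):
        prepared(torch.randn(2, 8))

    quantized = tq.convert(prepared)
    print(quantized)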
@@ -372,7 +355,9 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
     assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
     if cfg.QUANTIZATION.EAGER_MODE:
-        for n_k, o_k in zip(new_state_dict_non_observer_keys, original_state_dict_shapes):
+        for n_k, o_k in zip(
+            new_state_dict_non_observer_keys, original_state_dict_shapes
+        ):
             assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
     # _q_state_dict_map will store
     model._non_qat_to_qat_state_dict_map = dict(
@@ -36,7 +36,6 @@ from d2go.modeling.model_freezing_utils import (
 from d2go.modeling.quantization import (
     QATCheckpointer,
     setup_qat_model,
-    silicon_qat_build_model_context,
 )
 from d2go.optimizer import build_optimizer_mapper
 from d2go.utils.flop_calculator import add_print_flops_callback
@@ -151,7 +150,7 @@ class BaseRunner(object):
         torch._C._log_api_usage_once(identifier)
     def _initialize(self, cfg):
-        """ Runner should be initialized in the sub-process in ddp setting """
+        """Runner should be initialized in the sub-process in ddp setting"""
         if getattr(self, "_has_initialized", False):
             logger.warning("Runner has already been initialized, skip initialization.")
             return
@@ -174,10 +173,13 @@ class BaseRunner(object):
         from detectron2.config import get_cfg as get_d2_cfg
         cfg = get_d2_cfg()
-        cfg = CfgNode.cast_from_other_class(cfg) # upgrade from D2's CfgNode to D2Go's CfgNode
+        cfg = CfgNode.cast_from_other_class(
+            cfg
+        )  # upgrade from D2's CfgNode to D2Go's CfgNode
         try:
             from d2go.runner import get_unintentional_added_configs_during_runner_import
             for key in get_unintentional_added_configs_during_runner_import():
                 cfg.register_deprecated_key(key)
         except ImportError:
@@ -242,8 +244,6 @@ class Detectron2GoRunner(BaseRunner):
         # build_model might modify the cfg, thus clone
         cfg = cfg.clone()
-        # silicon_qat_build_model_context is deprecated
-        with silicon_qat_build_model_context(cfg):
-            model = build_model(cfg)
+        model = build_model(cfg)
         model_ema.may_build_model_ema(cfg, model)
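The dropped wrapper here (see the removed silicon_qat_build_model_context in the first hunk) followed a common pattern: a generator-based context manager that conditionally enters a variable number of patch contexts through contextlib.ExitStack. A minimal sketch of that pattern with stand-in patch targets, not the real mobile_cv ops:

    import contextlib
    from unittest import mock


    @contextlib.contextmanager
    def optional_patches_context(enabled):
        managers = []
        if enabled:
            # In the removed code these were mock_quant_ops(...) patch contexts.
            managers.append(mock.patch.dict("os.environ", {"EXAMPLE_FLAG": "1"}))

        with contextlib.ExitStack() as stack:
            for mgr in managers:
                stack.enter_context(mgr)  # all entered managers exit together
            yield


    with optional_patches_context(enabled=False):
        pass  # build_model(cfg) would run here; nothing is patched when disabled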