Commit 62c21f25 authored by Yanghan Wang, committed by Facebook GitHub Bot

remove deprecated silicon quantization

Summary: EZ

Reviewed By: zhanghang1989

Differential Revision: D29000628

fbshipit-source-id: f954214dfe3a989fc145663f8bb1870812e78ce7
parent fc690b45
@@ -113,33 +113,12 @@ def add_quantization_default_configs(_C):
     _C.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES = 1
     _C.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU = False
-    # deprecated
-    _C.QUANTIZATION.SILICON_QAT = CfgNode()
-    _C.QUANTIZATION.SILICON_QAT.ENABLED = False
 
     # register deprecated and renamed keys
     _C.register_deprecated_key("QUANTIZATION.QAT.LOAD_PRETRAINED")
     _C.register_renamed_key("QUANTIZATION.QAT.BACKEND", "QUANTIZATION.BACKEND")
     _C.register_deprecated_key("QUANTIZATION.ENABLE_CUSTOM_QSCHEME")
+    _C.register_deprecated_key("QUANTIZATION.SILICON_QAT")
+    _C.register_deprecated_key("QUANTIZATION.SILICON_QAT.ENABLED")
 
 
-@contextlib.contextmanager
-def silicon_qat_build_model_context(cfg):
-    mock_ctx_managers = []
-    if cfg.QUANTIZATION.SILICON_QAT.ENABLED:
-        from mobile_cv.silicon_pytorch_qat.replace_op import mock_quant_ops
-
-        mock_ctx_managers.extend(
-            [
-                mock_quant_ops(quant_op="quant_add"),
-                mock_quant_ops(quant_op="quant_fbb_convbnrelu"),
-            ]
-        )
-    with contextlib.ExitStack() as stack:
-        for mgr in mock_ctx_managers:
-            stack.enter_context(mgr)
-        yield
-
-
 # TODO: model.to(device) might not work for detection meta-arch, this function is the
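
For reference, the removed `silicon_qat_build_model_context` relied on `contextlib.ExitStack` to enter a variable-length list of context managers (the mocked quant ops) around model building. A minimal, self-contained sketch of that pattern, with a generic manager list standing in for the `mock_quant_ops` calls:

import contextlib

@contextlib.contextmanager
def build_model_context(ctx_managers):
    # Enter every manager in the list; ExitStack unwinds them in
    # reverse order when the with-block exits, even on exceptions.
    with contextlib.ExitStack() as stack:
        for mgr in ctx_managers:
            stack.enter_context(mgr)
        yield

# With an empty list this degenerates to a no-op context, which is why
# the call site could always wrap model building unconditionally.
with build_model_context([]):
    pass  # build_model(cfg) would run here
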
@@ -168,26 +147,29 @@ def _cast_detection_model(model, device):
 
 
 def add_d2_quant_mapping(mappings):
-    """ HACK: Add d2 specific module mapping for eager model quantization
-    """
+    """HACK: Add d2 specific module mapping for eager model quantization"""
     import torch.quantization.quantization_mappings as qm
 
     for k, v in mappings.items():
         if k not in qm.get_default_static_quant_module_mappings():
             qm.DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[k] = v
         if k not in qm.get_default_qat_module_mappings():
             qm.DEFAULT_QAT_MODULE_MAPPINGS[k] = v
 
 
 # The `mock_quantization_type` decorate may not be needed anymore to unify
 # detectron2.layers modules and torch.nn modules since Pytorch 1.5. See comments on D23790034.
 def mock_quantization_type(quant_func):
-    import mock
     import builtins
     import functools
+
     import detectron2.layers as d2l
+    import mock
 
     type_mapping = {d2l.Linear: torch.nn.Linear}
 
     from d2go.utils.misc import check_version
-    if check_version(torch, '1.7.2', warning_only=True):
+
+    if check_version(torch, "1.7.2", warning_only=True):
         add_d2_quant_mapping(type_mapping)
 
     real_type = builtins.type
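
The `mock_quantization_type` decorator works by patching `builtins.type` for the duration of the wrapped call, so eager-mode quantization code that dispatches on `type(module)` sees `detectron2.layers.Linear` as `torch.nn.Linear`. A minimal sketch of that patching idea, using hypothetical stand-in classes instead of the real d2l and torch modules:

import builtins
import functools
from unittest import mock

class D2Linear:  # stand-in for detectron2.layers.Linear
    pass

class NNLinear:  # stand-in for torch.nn.Linear
    pass

TYPE_MAPPING = {D2Linear: NNLinear}
real_type = builtins.type  # keep a handle on the unpatched builtin

def _mapped_type(obj, *args):
    if args:  # preserve the 3-arg form: type(name, bases, namespace)
        return real_type(obj, *args)
    t = real_type(obj)
    return TYPE_MAPPING.get(t, t)  # remap only the 1-arg lookup

def mock_type(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # type() is only remapped while the wrapped function runs
        with mock.patch("builtins.type", side_effect=_mapped_type):
            return func(*args, **kwargs)
    return wrapper

@mock_type
def lookup(module):
    return type(module)

assert lookup(D2Linear()) is NNLinear
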
@@ -270,13 +252,14 @@ def default_prepare_for_quant(cfg, model):
     return model
 
 
 def default_prepare_for_quant_convert(cfg, model):
     return torch.quantization.quantize_fx.convert_fx(model)
 
 
 @mock_quantization_type
 def post_training_quantize(cfg, model, data_loader):
-    """ Calibrate a model, convert it to a quantized pytorch model """
+    """Calibrate a model, convert it to a quantized pytorch model"""
     model = copy.deepcopy(model)
     model.eval()
     # TODO: check why some parameters will have gradient
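
`post_training_quantize` follows the usual post-training quantization recipe: deep-copy the model so the trained weights stay untouched, switch to eval mode, insert observers, run calibration data through the model, then convert. A minimal eager-mode sketch with PyTorch's `torch.quantization` API (assuming the model already contains Quant/DeQuant stubs where needed):

import copy
import torch

def post_training_quantize_sketch(model, calibration_batches):
    model = copy.deepcopy(model)  # never mutate the trained model
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
    torch.quantization.prepare(model, inplace=True)  # attach observers
    with torch.no_grad():
        for batch in calibration_batches:  # collect activation statistics
            model(batch)
    return torch.quantization.convert(model)  # swap in quantized modules
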
@@ -372,7 +355,9 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
     assert len(new_state_dict_non_observer_keys) == len(original_state_dict_shapes)
     if cfg.QUANTIZATION.EAGER_MODE:
-        for n_k, o_k in zip(new_state_dict_non_observer_keys, original_state_dict_shapes):
+        for n_k, o_k in zip(
+            new_state_dict_non_observer_keys, original_state_dict_shapes
+        ):
             assert new_state_dict_shapes[n_k] == original_state_dict_shapes[o_k]
 
     # _q_state_dict_map will store
     model._non_qat_to_qat_state_dict_map = dict(
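
The reflowed `zip` above pairs state-dict keys positionally: QAT rewriting renames parameters but preserves their order, so the non-observer keys of the QAT state dict can be zipped against the original keys, shape-checked, and turned into a rename map for checkpoint loading. A toy illustration of that check (the dict contents here are made up):

import torch

original = {"conv.weight": torch.zeros(8, 3, 3, 3)}
qat = {
    "conv.weight": torch.zeros(8, 3, 3, 3),
    "conv.weight_fake_quant.scale": torch.ones(1),  # observer state, excluded
}

non_observer_keys = [k for k in qat if "fake_quant" not in k and "observer" not in k]
assert len(non_observer_keys) == len(original)
for n_k, o_k in zip(non_observer_keys, original):
    # dicts preserve insertion order, so positional pairing is valid
    assert qat[n_k].shape == original[o_k].shape

rename_map = dict(zip(non_observer_keys, original))  # QAT name -> original name
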
@@ -36,7 +36,6 @@ from d2go.modeling.model_freezing_utils import (
 from d2go.modeling.quantization import (
     QATCheckpointer,
     setup_qat_model,
-    silicon_qat_build_model_context,
 )
 from d2go.optimizer import build_optimizer_mapper
 from d2go.utils.flop_calculator import add_print_flops_callback
@@ -151,7 +150,7 @@ class BaseRunner(object):
         torch._C._log_api_usage_once(identifier)
 
     def _initialize(self, cfg):
-        """ Runner should be initialized in the sub-process in ddp setting """
+        """Runner should be initialized in the sub-process in ddp setting"""
         if getattr(self, "_has_initialized", False):
             logger.warning("Runner has already been initialized, skip initialization.")
             return
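
The `_has_initialized` check above is a standard idempotency guard: under DDP each worker process may trigger initialization more than once, and the flag turns repeat calls into a logged no-op. A minimal sketch of the pattern:

import logging

logger = logging.getLogger(__name__)

class Runner:
    def _initialize(self, cfg):
        # Set the flag before doing work so re-entrant calls bail out early.
        if getattr(self, "_has_initialized", False):
            logger.warning("Runner has already been initialized, skip initialization.")
            return
        self._has_initialized = True
        # ... one-time setup (registries, env, etc.) goes here ...
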
@@ -174,10 +173,13 @@ class BaseRunner(object):
         from detectron2.config import get_cfg as get_d2_cfg
 
         cfg = get_d2_cfg()
-        cfg = CfgNode.cast_from_other_class(cfg)  # upgrade from D2's CfgNode to D2Go's CfgNode
+        cfg = CfgNode.cast_from_other_class(
+            cfg
+        )  # upgrade from D2's CfgNode to D2Go's CfgNode
 
         try:
             from d2go.runner import get_unintentional_added_configs_during_runner_import
+
             for key in get_unintentional_added_configs_during_runner_import():
                 cfg.register_deprecated_key(key)
         except ImportError:
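
`cast_from_other_class` upgrades detectron2's plain `CfgNode` into D2Go's subclass so subclass-only methods such as `register_deprecated_key` become available on an existing config. One way such a cast can be implemented is to rebuild the node tree recursively under the new class; the sketch below illustrates the idea with hypothetical stand-in classes and is not D2Go's actual implementation:

class BaseNode(dict):
    """Stand-in for detectron2's CfgNode."""

class ExtNode(BaseNode):
    """Stand-in for D2Go's CfgNode, carrying extra methods."""

    @classmethod
    def cast_from_other_class(cls, other):
        # Re-wrap nested nodes so the whole tree uses the subclass.
        new = cls()
        for k, v in other.items():
            new[k] = cls.cast_from_other_class(v) if isinstance(v, BaseNode) else v
        return new

    def register_deprecated_key(self, key):
        self.setdefault("_DEPRECATED_KEYS", set()).add(key)

cfg = ExtNode.cast_from_other_class(BaseNode(MODEL=BaseNode(DEVICE="cpu")))
assert isinstance(cfg["MODEL"], ExtNode)  # nested nodes are upgraded too
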
@@ -242,10 +244,8 @@ class Detectron2GoRunner(BaseRunner):
         # build_model might modify the cfg, thus clone
         cfg = cfg.clone()
-        # silicon_qat_build_model_context is deprecated
-        with silicon_qat_build_model_context(cfg):
-            model = build_model(cfg)
-            model_ema.may_build_model_ema(cfg, model)
+        model = build_model(cfg)
+        model_ema.may_build_model_ema(cfg, model)
 
         if cfg.MODEL.FROZEN_LAYER_REG_EXP:
             set_requires_grad(model, cfg.MODEL.FROZEN_LAYER_REG_EXP, False)