"...git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "5aee6c0446b786e2ef3a7e6077a8456417e168ec"
Commit b375c290 authored by Olga Gerasimova's avatar Olga Gerasimova Committed by Facebook GitHub Bot
Browse files

disable_fake_quant on 0 step

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/623

If we load model d2go/runner/default_runner.py?lines=567
that was had enable_fake_quant, than on_begin_train we need to disable it.

Reviewed By: jiaxuzhu92

Differential Revision: D49911356

fbshipit-source-id: f51b2a043c0c3f754d5698eb4b5d968a28d601d1
parent 27918553
...@@ -594,6 +594,18 @@ class QATHook(HookBase): ...@@ -594,6 +594,18 @@ class QATHook(HookBase):
model = self.trainer.model model = self.trainer.model
cfg = self.cfg cfg = self.cfg
# if we load model in enable_fake_quant state, we need to disable fake quant again, if QAT.START_ITER > 0
if cur_iter < cfg.QUANTIZATION.QAT.START_ITER and cur_iter == 0:
logger.info(
"[QAT] disable fake quant to start QAT, iter = {}".format(cur_iter)
)
model.apply(torch.ao.quantization.disable_fake_quant)
model.apply(learnable_qat.disable_lqat_fake_quant)
self._applied["enable_fake_quant"] = False
_reset_qat_data_loader_if_needed(
self.cfg, self.trainer, self.build_data_loader_func
)
if ( if (
not self._applied["enable_fake_quant"] not self._applied["enable_fake_quant"]
and cur_iter >= cfg.QUANTIZATION.QAT.START_ITER and cur_iter >= cfg.QUANTIZATION.QAT.START_ITER
...@@ -609,6 +621,11 @@ class QATHook(HookBase): ...@@ -609,6 +621,11 @@ class QATHook(HookBase):
self.cfg, self.trainer, self.build_data_loader_func self.cfg, self.trainer, self.build_data_loader_func
) )
if cur_iter < cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER and cur_iter == 0:
logger.info("[QAT] disable static observer, iter = {}".format(cur_iter))
model.apply(torch.ao.quantization.disable_observer)
model.apply(learnable_qat.disable_lqat_static_observer)
self._applied["disable_observer"] = False
if ( if (
not self._applied["enable_observer"] not self._applied["enable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
...@@ -619,6 +636,13 @@ class QATHook(HookBase): ...@@ -619,6 +636,13 @@ class QATHook(HookBase):
model.apply(learnable_qat.enable_lqat_static_observer) model.apply(learnable_qat.enable_lqat_static_observer)
self._applied["enable_observer"] = True self._applied["enable_observer"] = True
if (
cur_iter < cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER
and cur_iter == 0
):
logger.info(f"[QAT] disabling learnable observer, iter = {cur_iter}")
model.apply(learnable_qat.disable_lqat_learnable_observer)
self._applied["disable_learnable_observer"] = False
if ( if (
not self._applied["enable_learnable_observer"] not self._applied["enable_learnable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment