"docs/vscode:/vscode.git/clone" did not exist on "3e43d7b8d203df1f2e2e2f0b5c029dfebeec549b"
Commit 5a068943 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

move quantization out from modeling

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/221

Reviewed By: tglik

Differential Revision: D35855051

fbshipit-source-id: f742dfbc91bb7a20f632a508743fa93e3a7e9aa9
parent 251eaed2
...@@ -106,7 +106,7 @@ def convert_predictor( ...@@ -106,7 +106,7 @@ def convert_predictor(
" training quantization ..." " training quantization ..."
) )
# delayed import to avoid circular import since d2go.modeling depends on d2go.export # delayed import to avoid circular import since d2go.modeling depends on d2go.export
from d2go.modeling.quantization import post_training_quantize from d2go.quantization.modeling import post_training_quantize
pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader) pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
# only check bn exists in ptq as qat still has bn inside fused ops # only check bn exists in ptq as qat still has bn inside fused ops
......
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
import torch.nn as nn import torch.nn as nn
from d2go.export.api import PredictorExportConfig from d2go.export.api import PredictorExportConfig
from d2go.modeling.quantization import set_backend_and_create_qconfig from d2go.quantization.modeling import set_backend_and_create_qconfig
from detectron2.modeling import GeneralizedRCNN from detectron2.modeling import GeneralizedRCNN
from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.postprocessing import detector_postprocess from detectron2.modeling.postprocessing import detector_postprocess
......
...@@ -5,7 +5,9 @@ import logging ...@@ -5,7 +5,9 @@ import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
from d2go.utils.qat_utils import iterate_module_named_parameters
# FIXME: optimizer should not depend on quantization (or vice versa)
from d2go.quantization.learnable_qat import iterate_module_named_parameters
from detectron2.solver.build import ( from detectron2.solver.build import (
maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping, maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
reduce_param_groups, reduce_param_groups,
......
...@@ -38,7 +38,7 @@ def check_for_learnable_fake_quant_ops(qat_method, model): ...@@ -38,7 +38,7 @@ def check_for_learnable_fake_quant_ops(qat_method, model):
if qat_method == "learnable": if qat_method == "learnable":
if not _has_module(model, _LearnableFakeQuantize): if not _has_module(model, _LearnableFakeQuantize):
raise Exception( raise Exception(
"No learnable fake quant is used for learnable quantzation, please use d2go.utils.qat_utils.get_learnable_qat_qconfig() to get proper qconfig" "No learnable fake quant is used for learnable quantzation, please use d2go.quantization.learnable_qat.get_learnable_qat_qconfig() to get proper qconfig"
) )
......
...@@ -7,9 +7,9 @@ import logging ...@@ -7,9 +7,9 @@ import logging
import math import math
from typing import Tuple from typing import Tuple
import d2go.utils.qat_utils as qat_utils
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
from d2go.quantization import learnable_qat
from detectron2.checkpoint import DetectionCheckpointer from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import HookBase from detectron2.engine import HookBase
from detectron2.engine import SimpleTrainer from detectron2.engine import SimpleTrainer
...@@ -291,7 +291,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train): ...@@ -291,7 +291,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
backend=backend, is_qat=is_train, use_symmetric=is_symmetric backend=backend, is_qat=is_train, use_symmetric=is_symmetric
) )
if is_train and qat_method == "learnable": if is_train and qat_method == "learnable":
qconfig = qat_utils.convert_to_learnable_qconfig(qconfig) qconfig = learnable_qat.convert_to_learnable_qconfig(qconfig)
return qconfig return qconfig
...@@ -450,7 +450,7 @@ def setup_qat_model( ...@@ -450,7 +450,7 @@ def setup_qat_model(
torch.ao.quantization.prepare_qat(model, inplace=True) torch.ao.quantization.prepare_qat(model, inplace=True)
# make sure the proper qconfig are used in the model # make sure the proper qconfig are used in the model
qat_utils.check_for_learnable_fake_quant_ops(qat_method, model) learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
# Move newly added observers to the original device # Move newly added observers to the original device
model.to(device) model.to(device)
...@@ -458,14 +458,14 @@ def setup_qat_model( ...@@ -458,14 +458,14 @@ def setup_qat_model(
if not enable_fake_quant: if not enable_fake_quant:
logger.info("Disabling fake quant ...") logger.info("Disabling fake quant ...")
model.apply(torch.ao.quantization.disable_fake_quant) model.apply(torch.ao.quantization.disable_fake_quant)
model.apply(qat_utils.disable_lqat_fake_quant) model.apply(learnable_qat.disable_lqat_fake_quant)
if not enable_observer: if not enable_observer:
logger.info("Disabling static observer ...") logger.info("Disabling static observer ...")
model.apply(torch.ao.quantization.disable_observer) model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer) model.apply(learnable_qat.disable_lqat_static_observer)
if not enable_learnable_observer and qat_method == "learnable": if not enable_learnable_observer and qat_method == "learnable":
logger.info("Disabling learnable observer ...") logger.info("Disabling learnable observer ...")
model.apply(qat_utils.disable_lqat_learnable_observer) model.apply(learnable_qat.disable_lqat_learnable_observer)
# qat state dict mapper # qat state dict mapper
if not getattr(model, "_non_qat_to_qat_state_dict_map", None): if not getattr(model, "_non_qat_to_qat_state_dict_map", None):
...@@ -474,7 +474,7 @@ def setup_qat_model( ...@@ -474,7 +474,7 @@ def setup_qat_model(
) )
# qat optimizer group for learnable qat # qat optimizer group for learnable qat
model = qat_utils.setup_qat_get_optimizer_param_groups(model, qat_method) model = learnable_qat.setup_qat_get_optimizer_param_groups(model, qat_method)
return model return model
...@@ -560,7 +560,7 @@ class QATHook(HookBase): ...@@ -560,7 +560,7 @@ class QATHook(HookBase):
"[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter) "[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
) )
model.apply(torch.ao.quantization.enable_fake_quant) model.apply(torch.ao.quantization.enable_fake_quant)
model.apply(qat_utils.enable_lqat_fake_quant) model.apply(learnable_qat.enable_lqat_fake_quant)
self._applied["enable_fake_quant"] = True self._applied["enable_fake_quant"] = True
_reset_qat_data_loader_if_needed( _reset_qat_data_loader_if_needed(
...@@ -574,7 +574,7 @@ class QATHook(HookBase): ...@@ -574,7 +574,7 @@ class QATHook(HookBase):
): ):
logger.info("[QAT] enable static observer, iter = {}".format(cur_iter)) logger.info("[QAT] enable static observer, iter = {}".format(cur_iter))
model.apply(torch.ao.quantization.enable_observer) model.apply(torch.ao.quantization.enable_observer)
model.apply(qat_utils.enable_lqat_static_observer) model.apply(learnable_qat.enable_lqat_static_observer)
self._applied["enable_observer"] = True self._applied["enable_observer"] = True
if ( if (
...@@ -582,7 +582,7 @@ class QATHook(HookBase): ...@@ -582,7 +582,7 @@ class QATHook(HookBase):
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER
): ):
logger.info(f"[QAT] enabling learnable observer, iter = {cur_iter}") logger.info(f"[QAT] enabling learnable observer, iter = {cur_iter}")
model.apply(qat_utils.enable_lqat_learnable_observer) model.apply(learnable_qat.enable_lqat_learnable_observer)
self._applied["enable_learnable_observer"] = True self._applied["enable_learnable_observer"] = True
if ( if (
...@@ -593,8 +593,8 @@ class QATHook(HookBase): ...@@ -593,8 +593,8 @@ class QATHook(HookBase):
"[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter) "[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter)
) )
model.apply(torch.ao.quantization.disable_observer) model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer) model.apply(learnable_qat.disable_lqat_static_observer)
model.apply(qat_utils.disable_lqat_learnable_observer) model.apply(learnable_qat.disable_lqat_learnable_observer)
self._applied["disable_observer"] = True self._applied["disable_observer"] = True
if ( if (
......
...@@ -7,7 +7,7 @@ import random ...@@ -7,7 +7,7 @@ import random
import torch import torch
import torch.nn as nn import torch.nn as nn
from d2go.modeling.quantization import QATCheckpointer from d2go.quantization.modeling import QATCheckpointer
from d2go.runner.default_runner import BaseRunner from d2go.runner.default_runner import BaseRunner
from d2go.utils.get_default_cfg import add_tensorboard_default_configs from d2go.utils.get_default_cfg import add_tensorboard_default_configs
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
......
...@@ -32,12 +32,12 @@ from d2go.modeling.model_freezing_utils import ( ...@@ -32,12 +32,12 @@ from d2go.modeling.model_freezing_utils import (
freeze_matched_bn, freeze_matched_bn,
set_requires_grad, set_requires_grad,
) )
from d2go.modeling.quantization import ( from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
QATCheckpointer, QATCheckpointer,
setup_qat_model, setup_qat_model,
QATHook, QATHook,
) )
from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import attach_profilers from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer from d2go.utils.helper import TensorboardXWriter, D2Trainer
......
...@@ -21,11 +21,11 @@ from d2go.modeling import build_model ...@@ -21,11 +21,11 @@ from d2go.modeling import build_model
from d2go.modeling.model_freezing_utils import ( from d2go.modeling.model_freezing_utils import (
set_requires_grad, set_requires_grad,
) )
from d2go.modeling.quantization import ( from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
default_prepare_for_quant, default_prepare_for_quant,
default_prepare_for_quant_convert, default_prepare_for_quant_convert,
) )
from d2go.optimizer import build_optimizer_mapper
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import ( from d2go.runner.default_runner import (
Detectron2GoRunner, Detectron2GoRunner,
......
...@@ -10,8 +10,8 @@ from d2go.modeling import kmeans_anchors, model_ema ...@@ -10,8 +10,8 @@ from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.meta_arch.fcos import add_fcos_configs from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.quantization import add_quantization_default_configs
from d2go.modeling.subclass import add_subclass_configs from d2go.modeling.subclass import add_subclass_configs
from d2go.quantization.modeling import add_quantization_default_configs
def add_tensorboard_default_configs(_C): def add_tensorboard_default_configs(_C):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
from d2go.modeling.quantization import set_backend_and_create_qconfig from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.utils.testing.data_loader_helper import create_local_dataset from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances from detectron2.structures import Boxes, ImageList, Instances
......
...@@ -10,7 +10,7 @@ from typing import Dict ...@@ -10,7 +10,7 @@ from typing import Dict
import pytorch_lightning as pl # type: ignore import pytorch_lightning as pl # type: ignore
import torch import torch
from d2go.config import CfgNode, temp_defrost from d2go.config import CfgNode, temp_defrost
from d2go.modeling.quantization import set_backend_and_create_qconfig from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.runner import create_runner from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import ( from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining, QuantizationAwareTraining,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment