Commit 5a068943 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

move quantization out from modeling

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/221

Reviewed By: tglik

Differential Revision: D35855051

fbshipit-source-id: f742dfbc91bb7a20f632a508743fa93e3a7e9aa9
parent 251eaed2
......@@ -106,7 +106,7 @@ def convert_predictor(
" training quantization ..."
)
# delayed import to avoid circular import since d2go.modeling depends on d2go.export
from d2go.modeling.quantization import post_training_quantize
from d2go.quantization.modeling import post_training_quantize
pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
# only check bn exists in ptq as qat still has bn inside fused ops
......
......@@ -7,7 +7,7 @@ import logging
import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from d2go.modeling.quantization import set_backend_and_create_qconfig
from d2go.quantization.modeling import set_backend_and_create_qconfig
from detectron2.modeling import GeneralizedRCNN
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.postprocessing import detector_postprocess
......
......@@ -5,7 +5,9 @@ import logging
from typing import Any, Dict, List, Optional, Union
import torch
from d2go.utils.qat_utils import iterate_module_named_parameters
# FIXME: optimizer should not depend on quantization (or vice versa)
from d2go.quantization.learnable_qat import iterate_module_named_parameters
from detectron2.solver.build import (
maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
reduce_param_groups,
......
......@@ -38,7 +38,7 @@ def check_for_learnable_fake_quant_ops(qat_method, model):
if qat_method == "learnable":
if not _has_module(model, _LearnableFakeQuantize):
raise Exception(
"No learnable fake quant is used for learnable quantzation, please use d2go.utils.qat_utils.get_learnable_qat_qconfig() to get proper qconfig"
"No learnable fake quant is used for learnable quantzation, please use d2go.quantization.learnable_qat.get_learnable_qat_qconfig() to get proper qconfig"
)
......
......@@ -7,9 +7,9 @@ import logging
import math
from typing import Tuple
import d2go.utils.qat_utils as qat_utils
import detectron2.utils.comm as comm
import torch
from d2go.quantization import learnable_qat
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import HookBase
from detectron2.engine import SimpleTrainer
......@@ -291,7 +291,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
backend=backend, is_qat=is_train, use_symmetric=is_symmetric
)
if is_train and qat_method == "learnable":
qconfig = qat_utils.convert_to_learnable_qconfig(qconfig)
qconfig = learnable_qat.convert_to_learnable_qconfig(qconfig)
return qconfig
......@@ -450,7 +450,7 @@ def setup_qat_model(
torch.ao.quantization.prepare_qat(model, inplace=True)
# make sure the proper qconfig are used in the model
qat_utils.check_for_learnable_fake_quant_ops(qat_method, model)
learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
# Move newly added observers to the original device
model.to(device)
......@@ -458,14 +458,14 @@ def setup_qat_model(
if not enable_fake_quant:
logger.info("Disabling fake quant ...")
model.apply(torch.ao.quantization.disable_fake_quant)
model.apply(qat_utils.disable_lqat_fake_quant)
model.apply(learnable_qat.disable_lqat_fake_quant)
if not enable_observer:
logger.info("Disabling static observer ...")
model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer)
model.apply(learnable_qat.disable_lqat_static_observer)
if not enable_learnable_observer and qat_method == "learnable":
logger.info("Disabling learnable observer ...")
model.apply(qat_utils.disable_lqat_learnable_observer)
model.apply(learnable_qat.disable_lqat_learnable_observer)
# qat state dict mapper
if not getattr(model, "_non_qat_to_qat_state_dict_map", None):
......@@ -474,7 +474,7 @@ def setup_qat_model(
)
# qat optimizer group for learnable qat
model = qat_utils.setup_qat_get_optimizer_param_groups(model, qat_method)
model = learnable_qat.setup_qat_get_optimizer_param_groups(model, qat_method)
return model
......@@ -560,7 +560,7 @@ class QATHook(HookBase):
"[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
)
model.apply(torch.ao.quantization.enable_fake_quant)
model.apply(qat_utils.enable_lqat_fake_quant)
model.apply(learnable_qat.enable_lqat_fake_quant)
self._applied["enable_fake_quant"] = True
_reset_qat_data_loader_if_needed(
......@@ -574,7 +574,7 @@ class QATHook(HookBase):
):
logger.info("[QAT] enable static observer, iter = {}".format(cur_iter))
model.apply(torch.ao.quantization.enable_observer)
model.apply(qat_utils.enable_lqat_static_observer)
model.apply(learnable_qat.enable_lqat_static_observer)
self._applied["enable_observer"] = True
if (
......@@ -582,7 +582,7 @@ class QATHook(HookBase):
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER
):
logger.info(f"[QAT] enabling learnable observer, iter = {cur_iter}")
model.apply(qat_utils.enable_lqat_learnable_observer)
model.apply(learnable_qat.enable_lqat_learnable_observer)
self._applied["enable_learnable_observer"] = True
if (
......@@ -593,8 +593,8 @@ class QATHook(HookBase):
"[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter)
)
model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer)
model.apply(qat_utils.disable_lqat_learnable_observer)
model.apply(learnable_qat.disable_lqat_static_observer)
model.apply(learnable_qat.disable_lqat_learnable_observer)
self._applied["disable_observer"] = True
if (
......
......@@ -7,7 +7,7 @@ import random
import torch
import torch.nn as nn
from d2go.modeling.quantization import QATCheckpointer
from d2go.quantization.modeling import QATCheckpointer
from d2go.runner.default_runner import BaseRunner
from d2go.utils.get_default_cfg import add_tensorboard_default_configs
from detectron2.utils.file_io import PathManager
......
......@@ -32,12 +32,12 @@ from d2go.modeling.model_freezing_utils import (
freeze_matched_bn,
set_requires_grad,
)
from d2go.modeling.quantization import (
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
QATCheckpointer,
setup_qat_model,
QATHook,
)
from d2go.optimizer import build_optimizer_mapper
from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import TensorboardXWriter, D2Trainer
......
......@@ -21,11 +21,11 @@ from d2go.modeling import build_model
from d2go.modeling.model_freezing_utils import (
set_requires_grad,
)
from d2go.modeling.quantization import (
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
default_prepare_for_quant,
default_prepare_for_quant_convert,
)
from d2go.optimizer import build_optimizer_mapper
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import (
Detectron2GoRunner,
......
......@@ -10,8 +10,8 @@ from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.backbone.fbnet_cfg import add_fbnet_v2_default_configs
from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.quantization import add_quantization_default_configs
from d2go.modeling.subclass import add_subclass_configs
from d2go.quantization.modeling import add_quantization_default_configs
def add_tensorboard_default_configs(_C):
......
......@@ -3,7 +3,7 @@
import torch
from d2go.modeling.quantization import set_backend_and_create_qconfig
from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
......
......@@ -10,7 +10,7 @@ from typing import Dict
import pytorch_lightning as pl # type: ignore
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.modeling.quantization import set_backend_and_create_qconfig
from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment