Commit aeb24a92 authored by Kai Zhang, committed by Facebook GitHub Bot

Delegate FX quantization customization to the model

Summary: Delegate the FX quantization callback's customization to the model: when the model defines prepare_for_quant / prepare_for_quant_convert, the QAT callback uses those hooks instead of its default FX prepare/convert path.

Reviewed By: wat3rBro

Differential Revision: D27669212

fbshipit-source-id: 2715546cf03134896da6f95ecddaf8503ff95d0b
parent 845d0b2c
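
For context, the delegation pattern this commit introduces can be summarized with a small, self-contained sketch. This is not the D2Go code itself: the toy module, the qnnpack backend, and the standalone `callback_prepare` / `callback_convert` helpers are illustrative assumptions, and it uses the same PyTorch ~1.8-era `torch.quantization.quantize_fx` API that the diff imports. The idea: the callback's prepare and convert steps prefer the model's own `prepare_for_quant` / `prepare_for_quant_convert` hooks when they exist, and fall back to plain FX graph-mode quantization otherwise.

```python
# Minimal sketch of the delegation pattern (illustrative names; not D2Go code).
import torch
import torch.nn as nn
from torch.quantization import get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx


class ToyMetaArch(nn.Module):
    """Stand-in for a meta-arch that owns its quantization customization."""

    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

    def forward(self, x):
        return self.head(x)

    def prepare_for_quant(self, cfg=None):
        # Model-side customization: quantize only the submodule it chooses.
        self.head = prepare_qat_fx(self.head, {"": get_default_qat_qconfig("qnnpack")})
        return self

    def prepare_for_quant_convert(self, cfg=None):
        self.head = convert_fx(self.head)
        return self


def callback_prepare(root: nn.Module) -> nn.Module:
    # Mirrors the callback's prepare step: prefer the model's hook if present.
    if hasattr(root, "prepare_for_quant"):
        return root.prepare_for_quant()
    return prepare_qat_fx(root, {"": get_default_qat_qconfig("qnnpack")})


def callback_convert(root: nn.Module) -> nn.Module:
    # Mirrors the callback's convert step: prefer the model's hook if present.
    if hasattr(root, "prepare_for_quant_convert"):
        return root.prepare_for_quant_convert()
    return convert_fx(root)


if __name__ == "__main__":
    model = ToyMetaArch().train()
    prepared = callback_prepare(model)    # QAT-prepared via the model's hook
    prepared(torch.rand(1, 3, 8, 8))      # let observers/fake-quant see data
    quantized = callback_convert(prepared.eval())
    print(type(quantized.head))           # torch.fx.GraphModule
```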
......@@ -7,6 +7,7 @@ from functools import lru_cache
import torch
from d2go.export.api import PredictorExportConfig
from d2go.utils.prepare_for_export import d2_meta_arch_prepare_for_export
from detectron2.export.caffe2_modeling import (
META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
convert_batched_inputs_to_c2_format,
......@@ -21,7 +22,7 @@ from mobile_cv.arch.utils.quantize_utils import (
QuantWrapper,
)
from mobile_cv.predictor.api import FuncInfo
from d2go.utils.prepare_for_export import d2_meta_arch_prepare_for_export
from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
logger = logging.getLogger(__name__)
......@@ -41,6 +42,16 @@ def patch_d2_meta_arch():
else:
cls_obj.prepare_for_quant = d2_meta_arch_prepare_for_quant
if hasattr(cls_obj, "prepare_for_quant_convert"):
assert (
cls_obj.prepare_for_quant_convert
== d2_meta_arch_prepare_for_quant_convert
)
else:
cls_obj.prepare_for_quant_convert = (
d2_meta_arch_prepare_for_quant_convert
)
def _apply_eager_mode_quant(cfg, model):
......@@ -107,15 +118,37 @@ def _apply_eager_mode_quant(cfg, model):
return model
def d2_meta_arch_prepare_for_quant(self, cfg):
model = self
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
def _fx_quant_prepare(self, cfg):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
self.backbone = prep_fn(
self.backbone,
qconfig,
{"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv, qconfig
)
self.roi_heads.box_head.avgpool = prep_fn(self.roi_heads.box_head.avgpool, qconfig)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score, qconfig
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred, qconfig
)
model = fuse_utils.fuse_model(model, inplace=True)
def d2_meta_arch_prepare_for_quant(self, cfg):
model = self
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
model.qconfig = (
torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
......@@ -124,4 +157,41 @@ def d2_meta_arch_prepare_for_quant(self, cfg):
)
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(model, inplace=True)
else:
_fx_quant_prepare(model, cfg)
return model
def d2_meta_arch_prepare_for_quant_convert(self, cfg):
if cfg.QUANTIZATION.EAGER_MODE:
raise NotImplementedError()
self.backbone = convert_fx(
self.backbone,
convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = convert_fx(
self.proposal_generator.rpn_head.rpn_feature
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
)
self.roi_heads.box_head.roi_box_conv = convert_fx(
self.roi_heads.box_head.roi_box_conv
)
self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
self.roi_heads.box_predictor.cls_score = convert_fx(
self.roi_heads.box_predictor.cls_score
)
self.roi_heads.box_predictor.bbox_pred = convert_fx(
self.roi_heads.box_predictor.bbox_pred
)
return self
......@@ -270,6 +270,9 @@ def default_prepare_for_quant(cfg, model):
return model
def default_prepare_for_quant_convert(cfg, model):
return torch.quantization.quantize_fx.convert_fx(model)
@mock_quantization_type
def post_training_quantize(cfg, model, data_loader):
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.utils.registry import Registry
from d2go.config import CfgNode
CALLBACK_REGISTRY = Registry("D2GO_CALLBACK_REGISTRY")
def build_quantization_callback(cfg: CfgNode):
return CALLBACK_REGISTRY.get(cfg.QUANTIZATION.NAME).from_config(cfg)
......@@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union
import torch
from d2go.config import CfgNode
from d2go.utils.misc import mode
from d2go.runner.callbacks.build import CALLBACK_REGISTRY
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
......@@ -26,6 +25,7 @@ from torch.quantization.utils import get_quant_type
QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
PREPARED = "_prepared"
def rsetattr(obj: Any, attr: str, val: Any) -> None:
......@@ -96,6 +96,20 @@ def _requires_calibration(config_dicts: QConfigDicts) -> bool:
return False
def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
return any(k.startswith(PREPARED) for k in checkpoint["state_dict"].keys())
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint
setattr(
model,
PREPARED,
_deepcopy(model).prepare_for_quant()
)
class QuantizationMixin(ABC):
"""Mixin defining an overrideable API for quantization customization.
......@@ -155,6 +169,8 @@ class QuantizationMixin(ABC):
Returns:
The prepared Module to be used for quantized aware training.
"""
if hasattr(root, "prepare_for_quant"):
return root.prepare_for_quant()
prep_fn = (
prepare_qat_fx
if isinstance(self, QuantizationAwareTraining)
......@@ -191,6 +207,8 @@ class QuantizationMixin(ABC):
Returns:
The quantized model.
"""
if hasattr(root, "prepare_for_quant_convert"):
return root.prepare_for_quant_convert()
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
......@@ -238,7 +256,6 @@ class ModelTransform:
raise ValueError("interval must be positive.")
@CALLBACK_REGISTRY.register()
class QuantizationAwareTraining(Callback, QuantizationMixin):
"""Enable QAT of a model using the STL Trainer.
......@@ -305,6 +322,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
] = None,
preserved_attrs: Optional[List[str]] = None,
skip_conversion: bool = False,
) -> None:
"""
Args:
......@@ -399,6 +417,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
for key, value in qconfig_dicts.items()
}
self.quantized: Optional[torch.nn.Module] = None
self.skip_conversion = skip_conversion
@classmethod
def from_config(cls, cfg: CfgNode):
......@@ -412,6 +431,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
start_step=qat.START_ITER,
enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
freeze_bn_step=qat.FREEZE_BN_ITER,
skip_conversion=True, # convert_fx will be handled by D2Go exporter
)
if qat.UPDATE_OBSERVER_STATS_PERIODICALLY:
callback.transforms.append(
......@@ -477,7 +497,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Quantize the weights since training has finalized. """
if hasattr(pl_module, "_quantized"):
if hasattr(pl_module, "_quantized") or self.skip_conversion:
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
......@@ -497,7 +517,6 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
self.quantized = pl_module._quantized
@CALLBACK_REGISTRY.register()
class PostTrainingQuantization(Callback, QuantizationMixin):
"""Enable post-training quantization, such as dynamic, static, and weight-only.
......
......@@ -24,12 +24,14 @@ from d2go.runner.default_runner import (
)
from d2go.setup import setup_after_launch
from d2go.utils.ema_state import EMAState
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from detectron2.modeling import build_model
from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
build_optimizer as d2_build_optimizer,
)
from pytorch_lightning.utilities import rank_zero_info
from d2go.modeling.quantization import default_prepare_for_quant, default_prepare_for_quant_convert
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
......@@ -157,7 +159,8 @@ class DefaultTask(pl.LightningModule):
@classmethod
def build_model(cls, cfg: CfgNode, eval_only=False):
"""Builds D2go model instance from config.
"""Builds D2go model instance from config. If model has been prepared
for quantization, the function returns the prepared model.
NOTE: For backward compatible with existing D2Go tools. Prefer
`from_config` in other use cases.
......@@ -165,7 +168,10 @@ class DefaultTask(pl.LightningModule):
cfg: D2go config node.
eval_only: True if model should be in eval mode.
"""
return cls.from_config(cfg, eval_only).model
task = cls.from_config(cfg, eval_only)
if hasattr(task, PREPARED):
task = getattr(task, PREPARED)
return task.model
@classmethod
def get_default_cfg(cls):
......@@ -362,6 +368,8 @@ class DefaultTask(pl.LightningModule):
if not _is_lightning_checkpoint(checkpointed_state):
_convert_to_lightning(checkpointed_state)
maybe_prepare_for_quantization(self, checkpointed_state)
if self.ema_state:
if "model_ema" not in checkpointed_state:
rank_zero_info(
......@@ -374,6 +382,20 @@ class DefaultTask(pl.LightningModule):
# EMA state device not given, move to module device
self.ema_state.to(self.device)
def prepare_for_quant(self) -> pl.LightningModule:
if hasattr(self.model, "prepare_for_quant"):
self.model = self.model.prepare_for_quant(self.cfg)
else:
self.model = default_prepare_for_quant(self.cfg, self.model)
return self
def prepare_for_quant_convert(self) -> pl.LightningModule:
if hasattr(self.model, "prepare_for_quant_convert"):
self.model = self.model.prepare_for_quant_convert(self.cfg)
else:
self.model = default_prepare_for_quant_convert(self.cfg, self.model)
return self
class GeneralizedRCNNTask(DefaultTask):
@classmethod
......
......@@ -6,6 +6,7 @@ import torch
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
@META_ARCH_REGISTRY.register()
......@@ -50,6 +51,17 @@ class DetMetaArchForTest(torch.nn.Module):
ret = [{"instances": instance}]
return ret
def prepare_for_quant(self, cfg):
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": torch.quantization.get_default_qat_qconfig()},
)
return self
def prepare_for_quant_convert(self, cfg):
self.avgpool = convert_fx(self.avgpool)
return self
def get_det_meta_arch_cfg(cfg, dataset_name, output_dir):
cfg.MODEL.DEVICE = "cpu"
......
......@@ -11,12 +11,18 @@ import pytorch_lightning as pl # type: ignore
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
class TestLightningTask(unittest.TestCase):
......@@ -175,3 +181,80 @@ class TestLightningTask(unittest.TestCase):
model.state_dict(), task.ema_state.state_dict()
)
)
@tempdir
def test_qat(self, tmp_dir):
@META_ARCH_REGISTRY.register()
class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
custom_config_dict = {"preserved_attributes": ["preserved_attr"]}
def __init__(self, cfg):
super().__init__(cfg)
self.avgpool.preserved_attr = "foo"
self.avgpool.not_preserved_attr = "bar"
def prepare_for_quant(self, cfg):
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": torch.quantization.get_default_qat_qconfig()},
self.custom_config_dict,
)
return self
def prepare_for_quant_convert(self, cfg):
self.avgpool = convert_fx(
self.avgpool, convert_custom_config_dict=self.custom_config_dict
)
return self
cfg = self._get_cfg(tmp_dir)
cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
cfg.QUANTIZATION.QAT.ENABLED = True
task = GeneralizedRCNNTask(cfg)
callbacks = [
QuantizationAwareTraining.from_config(cfg),
ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
]
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=callbacks,
logger=None,
)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
prepared_avgpool = task._prepared.model.avgpool
self.assertEqual(prepared_avgpool.preserved_attr, "foo")
self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))
with temp_defrost(cfg):
cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))
@tempdir
def test_generalized_rcnn_qat(self, tmp_dir):
cfg = GeneralizedRCNNTask.get_default_cfg()
cfg.merge_from_file("detectron2go://e2e_mask_rcnn_fbnet_600_qat.yaml")
cfg.MODEL.DEVICE = "cpu"
cfg.QUANTIZATION.EAGER_MODE = False
cfg.OUTPUT_DIR = tmp_dir
task = GeneralizedRCNNTask(cfg)
callbacks = [
QuantizationAwareTraining.from_config(cfg),
ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
]
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=callbacks,
logger=None,
)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
......@@ -4,7 +4,6 @@
import logging
import os
from d2go.runner.callbacks.build import build_quantization_callback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
......@@ -67,8 +66,8 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
save_last=True,
),
]
if cfg.QUANTIZATION.NAME and cfg.QUANTIZATION.QAT.ENABLED:
callbacks.append(build_quantization_callback(cfg))
if cfg.QUANTIZATION.QAT.ENABLED:
callbacks.append(QuantizationAwareTraining.from_config(cfg))
return callbacks
......@@ -177,11 +176,6 @@ def main(
else:
model_configs = do_train(cfg, trainer, task)
for cb in trainer_params["callbacks"]:
if isinstance(cb, QuantizationAwareTraining):
print("################ quantized #################")
print(cb.quantized)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
tensorboard_log_dir=tb_logger.log_dir,
......