Commit 845d0b2c authored by Kai Zhang, committed by Facebook GitHub Bot

E2E QAT Workflow on Lightning

Summary:
As per the title: sanity-test the end-to-end QAT workflow on the Lightning Trainer.

- Add `post_training_opts`. This is required to use `all_steps_qat.json` with Lightning. Post-training quantization itself is not supported in this diff; that work is left to T83437359.
- Update the .yaml config to specify the quantizable modules (see the config sketch below).
- Update `lightning_train_net.py` to use the `QuantizationAwareTraining` callback.
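
A minimal sketch of how the new config keys could be wired up. The key names (`QUANTIZATION.NAME`, `QUANTIZATION.MODULES`, `QUANTIZATION.QAT.*`) come from this diff, but the concrete values and the registered callback name (assumed to be the class name, "QuantizationAwareTraining") are illustrative assumptions, not the shipped defaults:

from d2go.config import CfgNode
from d2go.runner.callbacks.build import build_quantization_callback

cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
# Name of the registered Lightning quantization callback (assumed to match the class name).
cfg.QUANTIZATION.NAME = "QuantizationAwareTraining"
# Hypothetical quantizable submodule; None would quantize the whole model.
cfg.QUANTIZATION.MODULES = ["model.backbone"]
cfg.QUANTIZATION.QAT = CfgNode()
cfg.QUANTIZATION.QAT.ENABLED = True
cfg.QUANTIZATION.QAT.START_ITER = 0
cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 0
cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 5000
cfg.QUANTIZATION.QAT.FREEZE_BN_ITER = 7000
cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY = False

# With QUANTIZATION.NAME set and QAT enabled, _get_trainer_callbacks() appends
# the callback built from the registry instead of constructing it inline:
callback = build_quantization_callback(cfg)
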

Reviewed By: kandluis

Differential Revision: D26304879

fbshipit-source-id: 948bef4817d385d8a0969e4990d7f17ecd6994b7
parent bd4ba04d
......@@ -79,6 +79,9 @@ def add_quantization_default_configs(_C):
# used to enable metarch set_custom_qscheme (need to implement)
# this is a limited implementation where only str is provided to change options
_C.QUANTIZATION.CUSTOM_QSCHEME = ""
_C.QUANTIZATION.MODULES = None
# Lightning quantization callback name
_C.QUANTIZATION.NAME = ""
# quantization-aware training
_C.QUANTIZATION.QAT = CfgNode()
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.utils.registry import Registry
from d2go.config import CfgNode
CALLBACK_REGISTRY = Registry("D2GO_CALLBACK_REGISTRY")
def build_quantization_callback(cfg: CfgNode):
return CALLBACK_REGISTRY.get(cfg.QUANTIZATION.NAME).from_config(cfg)
......@@ -7,7 +7,10 @@ from types import MethodType
from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union
import torch
from d2go.config import CfgNode
from d2go.utils.misc import mode
from d2go.runner.callbacks.build import CALLBACK_REGISTRY
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info
......@@ -235,6 +238,7 @@ class ModelTransform:
raise ValueError("interval must be positive.")
@CALLBACK_REGISTRY.register()
class QuantizationAwareTraining(Callback, QuantizationMixin):
"""Enable QAT of a model using the STL Trainer.
......@@ -396,6 +400,29 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
}
self.quantized: Optional[torch.nn.Module] = None
@classmethod
def from_config(cls, cfg: CfgNode):
qat = cfg.QUANTIZATION.QAT
callback = cls(
qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
if cfg.QUANTIZATION.MODULES
else None,
# We explicitly pass this to maintain properties for now.
preserved_attrs=["model.backbone.size_divisibility"],
start_step=qat.START_ITER,
enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
freeze_bn_step=qat.FREEZE_BN_ITER,
)
if qat.UPDATE_OBSERVER_STATS_PERIODICALLY:
callback.transforms.append(
ModelTransform(
interval=qat.UPDATE_OBSERVER_STATS_PERIOD,
fb=observer_update_stat,
message="Updating observers.",
)
)
return callback
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Override the model with a quantized-aware version on setup.
......@@ -470,6 +497,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
self.quantized = pl_module._quantized
@CALLBACK_REGISTRY.register()
class PostTrainingQuantization(Callback, QuantizationMixin):
"""Enable post-training quantization, such as dynamic, static, and weight-only.
......@@ -527,6 +555,16 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
self.quantized: Optional[torch.nn.Module] = None
self.should_calibrate = _requires_calibration(self.qconfig_dicts)
@classmethod
def from_config(cls, cfg: CfgNode):
return cls(
qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
if cfg.QUANTIZATION.MODULES
else None,
# We explicitly pass this to maintain properties for now.
preserved_attrs=["model.backbone.size_divisibility"],
)
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
On validation start, prepare a module for quantization by adding
......
......@@ -62,6 +62,12 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
][key]
del d2_checkpoint[_OLD_STATE_DICT_KEY][key]
if "model.pixel_mean" in d2_checkpoint[_OLD_STATE_DICT_KEY]:
del d2_checkpoint[_OLD_STATE_DICT_KEY]["model.pixel_mean"]
if "model.pixel_std" in d2_checkpoint[_OLD_STATE_DICT_KEY]:
del d2_checkpoint[_OLD_STATE_DICT_KEY]["model.pixel_std"]
for old, new in zip(
[_OLD_STATE_DICT_KEY, "iteration"], [_STATE_DICT_KEY, "global_step"]
):
......
......@@ -4,13 +4,17 @@
import logging
import os
from d2go.runner.callbacks.build import build_quantization_callback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
import pytorch_lightning as pl # type: ignore
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
ModelTransform,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser
from d2go.utils.misc import dump_trained_model_configs
......@@ -63,20 +67,8 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
save_last=True,
),
]
if cfg.QUANTIZATION.QAT.ENABLED:
qat = cfg.QUANTIZATION.QAT
callbacks.append(
QuantizationAwareTraining(
qconfig_dicts={
submodule: None for submodule in cfg.QUANTIZATION.MODULES
}
if cfg.QUANTIZATION.MODULES
else None,
start_step=qat.START_ITER,
enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
freeze_bn_step=qat.FREEZE_BN_ITER,
)
)
if cfg.QUANTIZATION.NAME and cfg.QUANTIZATION.QAT.ENABLED:
callbacks.append(build_quantization_callback(cfg))
return callbacks
......@@ -185,6 +177,11 @@ def main(
else:
model_configs = do_train(cfg, trainer, task)
for cb in trainer_params["callbacks"]:
if isinstance(cb, QuantizationAwareTraining):
print("################ quantized #################")
print(cb.quantized)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
tensorboard_log_dir=tb_logger.log_dir,
......