Commit 845d0b2c authored by Kai Zhang, committed by Facebook GitHub Bot

E2E QAT Workflow on Lightning

Summary:
As per title: sanity-test the E2E QAT workflow on the Lightning Trainer.

- Add `post_training_opts`. This is required to use `all_steps_qat.json` with Lightning. The post-training options are not actually supported in this diff; that work is left to T83437359.
- Update the .yaml config to specify the quantize-able modules (see the config sketch below).
- Update `lightning_train_net.py` to use the `QuantizationAwareTraining` callback.
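
For reference, a minimal sketch of the new config surface; the submodule path below is an illustrative placeholder, not the actual value from the updated yaml:

```python
# Hedged sketch: setting the new quantization keys programmatically.
from d2go.config import CfgNode

cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
# Name of the registered Lightning quantization callback to build.
cfg.QUANTIZATION.NAME = "QuantizationAwareTraining"
# Submodules to quantize ("model.backbone" is a placeholder); leaving this unset/None
# falls back to the callback's default qconfig behavior.
cfg.QUANTIZATION.MODULES = ["model.backbone"]
```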

Reviewed By: kandluis

Differential Revision: D26304879

fbshipit-source-id: 948bef4817d385d8a0969e4990d7f17ecd6994b7
parent bd4ba04d
@@ -79,6 +79,9 @@ def add_quantization_default_configs(_C):
# used to enable metarch set_custom_qscheme (need to implement)
# this is a limited implementation where only str is provided to change options
_C.QUANTIZATION.CUSTOM_QSCHEME = ""
_C.QUANTIZATION.MODULES = None
# Lightning quantization callback name
_C.QUANTIZATION.NAME = ""
# quantization-aware training
_C.QUANTIZATION.QAT = CfgNode()
...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.utils.registry import Registry

from d2go.config import CfgNode

# Registry of Lightning callbacks buildable from a D2Go config via QUANTIZATION.NAME.
CALLBACK_REGISTRY = Registry("D2GO_CALLBACK_REGISTRY")

def build_quantization_callback(cfg: CfgNode):
    return CALLBACK_REGISTRY.get(cfg.QUANTIZATION.NAME).from_config(cfg)
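
A hedged usage sketch of the registry plumbing above; `MyQATCallback` is a hypothetical callback used only for illustration:

```python
from d2go.config import CfgNode
from d2go.runner.callbacks.build import CALLBACK_REGISTRY, build_quantization_callback
from pytorch_lightning.callbacks import Callback


@CALLBACK_REGISTRY.register()
class MyQATCallback(Callback):
    """Hypothetical callback; real ones implement from_config(cfg) like the classes below."""

    @classmethod
    def from_config(cls, cfg: CfgNode):
        return cls()


cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
cfg.QUANTIZATION.NAME = "MyQATCallback"
callback = build_quantization_callback(cfg)  # looks up the class by name, then calls from_config
```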
@@ -7,7 +7,10 @@ from types import MethodType
from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union
import torch
from d2go.config import CfgNode
from d2go.utils.misc import mode
from d2go.runner.callbacks.build import CALLBACK_REGISTRY
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info
@@ -235,6 +238,7 @@ class ModelTransform:
raise ValueError("interval must be positive.")
@CALLBACK_REGISTRY.register()
class QuantizationAwareTraining(Callback, QuantizationMixin):
"""Enable QAT of a model using the STL Trainer.
@@ -396,6 +400,29 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
}
self.quantized: Optional[torch.nn.Module] = None
    @classmethod
    def from_config(cls, cfg: CfgNode):
        qat = cfg.QUANTIZATION.QAT
        callback = cls(
            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
            if cfg.QUANTIZATION.MODULES
            else None,
            # We explicitly pass this to maintain properties for now.
            preserved_attrs=["model.backbone.size_divisibility"],
            start_step=qat.START_ITER,
            enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
            freeze_bn_step=qat.FREEZE_BN_ITER,
        )
        if qat.UPDATE_OBSERVER_STATS_PERIODICALLY:
            callback.transforms.append(
                ModelTransform(
                    interval=qat.UPDATE_OBSERVER_STATS_PERIOD,
                    fn=observer_update_stat,
                    message="Updating observers.",
                )
            )
        return callback
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Override the model with a quantized-aware version on setup.
@@ -470,6 +497,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
self.quantized = pl_module._quantized
@CALLBACK_REGISTRY.register()
class PostTrainingQuantization(Callback, QuantizationMixin):
"""Enable post-training quantization, such as dynamic, static, and weight-only.
@@ -527,6 +555,16 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
self.quantized: Optional[torch.nn.Module] = None
self.should_calibrate = _requires_calibration(self.qconfig_dicts)
    @classmethod
    def from_config(cls, cfg: CfgNode):
        return cls(
            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
            if cfg.QUANTIZATION.MODULES
            else None,
            # We explicitly pass this to maintain properties for now.
            preserved_attrs=["model.backbone.size_divisibility"],
        )
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
On validation start, prepare a module for quantization by adding
...
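
For context, a hedged sketch of what `QuantizationAwareTraining.from_config` resolves to for a typical QAT config; the submodule path and iteration numbers are placeholders:

```python
from d2go.runner.callbacks.quantization import ModelTransform, QuantizationAwareTraining
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat

# Hand-built equivalent of from_config(cfg); all values are illustrative.
qat_callback = QuantizationAwareTraining(
    qconfig_dicts={"model.backbone": None},  # from cfg.QUANTIZATION.MODULES
    preserved_attrs=["model.backbone.size_divisibility"],
    start_step=0,  # cfg.QUANTIZATION.QAT.START_ITER
    enable_observer=(0, 200),  # (ENABLE_OBSERVER_ITER, DISABLE_OBSERVER_ITER)
    freeze_bn_step=300,  # cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
)
# With UPDATE_OBSERVER_STATS_PERIODICALLY enabled, observer statistics are refreshed on a schedule.
qat_callback.transforms.append(
    ModelTransform(
        interval=100,  # cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD
        fn=observer_update_stat,
        message="Updating observers.",
    )
)
```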
@@ -62,6 +62,12 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
][key]
del d2_checkpoint[_OLD_STATE_DICT_KEY][key]
if "model.pixel_mean" in d2_checkpoint[_OLD_STATE_DICT_KEY]:
del d2_checkpoint[_OLD_STATE_DICT_KEY]["model.pixel_mean"]
if "model.pixel_std" in d2_checkpoint[_OLD_STATE_DICT_KEY]:
del d2_checkpoint[_OLD_STATE_DICT_KEY]["model.pixel_std"]
for old, new in zip(
[_OLD_STATE_DICT_KEY, "iteration"], [_STATE_DICT_KEY, "global_step"]
):
...
@@ -4,13 +4,17 @@
import logging
import os
from d2go.runner.callbacks.build import build_quantization_callback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
import pytorch_lightning as pl  # type: ignore
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
ModelTransform,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.setup import basic_argument_parser
from d2go.utils.misc import dump_trained_model_configs from d2go.utils.misc import dump_trained_model_configs
@@ -63,20 +67,8 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
save_last=True,
),
]
-    if cfg.QUANTIZATION.QAT.ENABLED:
-        qat = cfg.QUANTIZATION.QAT
-        callbacks.append(
-            QuantizationAwareTraining(
-                qconfig_dicts={
-                    submodule: None for submodule in cfg.QUANTIZATION.MODULES
-                }
-                if cfg.QUANTIZATION.MODULES
-                else None,
-                start_step=qat.START_ITER,
-                enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
-                freeze_bn_step=qat.FREEZE_BN_ITER,
-            )
-        )
+    if cfg.QUANTIZATION.NAME and cfg.QUANTIZATION.QAT.ENABLED:
+        callbacks.append(build_quantization_callback(cfg))
return callbacks
@@ -185,6 +177,11 @@ def main(
else:
model_configs = do_train(cfg, trainer, task)
for cb in trainer_params["callbacks"]:
if isinstance(cb, QuantizationAwareTraining):
print("################ quantized #################")
print(cb.quantized)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
tensorboard_log_dir=tb_logger.log_dir,
...
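
Putting the pieces together, a hedged end-to-end sketch of the new wiring; `cfg` and `task` are assumed to come from the usual D2Go setup, and the trainer arguments are placeholders:

```python
import pytorch_lightning as pl

from d2go.runner.callbacks.build import build_quantization_callback
from d2go.runner.callbacks.quantization import QuantizationAwareTraining

callbacks = []
if cfg.QUANTIZATION.NAME and cfg.QUANTIZATION.QAT.ENABLED:
    callbacks.append(build_quantization_callback(cfg))

trainer = pl.Trainer(max_steps=cfg.SOLVER.MAX_ITER, callbacks=callbacks)
trainer.fit(task)  # task: e.g. a GeneralizedRCNNTask built from cfg

for cb in callbacks:
    if isinstance(cb, QuantizationAwareTraining):
        print(cb.quantized)  # quantized module captured by the callback, as main() prints above
```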