Commit f6ce583e authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Supported learnable qat.

Summary:
Supported learnable qat.
* Added a config key `QUANTIZATION.QAT.FAKE_QUANT_METHOD` to specify the QAT method (`default` or `learnable`).
* Added a config key `QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER` to specify the start iteration for learnable observers (before that it is using static observers).
* Custom quantization code needs to call `d2go.utils.qat_utils.get_qat_qconfig()` to get a proper qconfig for learnable QAT. An exception will be raised if the QAT method is learnable but no learnable observers are used in the model.
* Set the weight decay for scale/zero_point to 0 for the optimizer automatically.
* The way to use learnable QAT: enable static observers -> enable fake quant -> enable learnable observers -> freeze bn.

Differential Revision: D31370822

fbshipit-source-id: a5a5044a539d0d7fe1cc6b36e6821fc411ce752a
parent ef9c20cc
...@@ -8,6 +8,7 @@ import logging ...@@ -8,6 +8,7 @@ import logging
import math import math
from typing import Tuple from typing import Tuple
import d2go.utils.qat_utils as qat_utils
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
from detectron2.checkpoint import DetectionCheckpointer from detectron2.checkpoint import DetectionCheckpointer
...@@ -95,6 +96,8 @@ def add_quantization_default_configs(_C): ...@@ -95,6 +96,8 @@ def add_quantization_default_configs(_C):
# quantization-aware training # quantization-aware training
_C.QUANTIZATION.QAT = CfgNode() _C.QUANTIZATION.QAT = CfgNode()
_C.QUANTIZATION.QAT.ENABLED = False _C.QUANTIZATION.QAT.ENABLED = False
# Methods for QAT training, could be "default" or "learnable"
_C.QUANTIZATION.QAT.FAKE_QUANT_METHOD = "default"
# QAT will use more GPU memory, user can change this factor to reduce the batch size # QAT will use more GPU memory, user can change this factor to reduce the batch size
# after fake quant is enabled. Setting it to 0.5 should guarantee no memory increase # after fake quant is enabled. Setting it to 0.5 should guarantee no memory increase
# compared with QAT is disabled. # compared with QAT is disabled.
...@@ -106,6 +109,8 @@ def add_quantization_default_configs(_C): ...@@ -106,6 +109,8 @@ def add_quantization_default_configs(_C):
# the iteration number to enable observer, it's usually set to be the same as # the iteration number to enable observer, it's usually set to be the same as
# QUANTIZATION.QAT.START_ITER. # QUANTIZATION.QAT.START_ITER.
_C.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 35000 _C.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER = 35000
# the iteration number to enable learnable observer, only used when METHOD == "learnable"
_C.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER = 36000
# the iteration number to disable observer, here it's 3k after enabling the fake # the iteration number to disable observer, here it's 3k after enabling the fake
# quant, 3k roughly corresponds to 7 out of 90 epochs in classification. # quant, 3k roughly corresponds to 7 out of 90 epochs in classification.
_C.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 35000 + 3000 _C.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER = 35000 + 3000
...@@ -238,7 +243,9 @@ def default_prepare_for_quant(cfg, model): ...@@ -238,7 +243,9 @@ def default_prepare_for_quant(cfg, model):
nn.Module: a ready model for QAT training or PTQ calibration nn.Module: a ready model for QAT training or PTQ calibration
""" """
qconfig = ( qconfig = (
torch.ao.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND) qat_utils.get_qat_qconfig(
cfg.QUANTIZATION.BACKEND, cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
)
if model.training if model.training
else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND) else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
) )
...@@ -346,12 +353,16 @@ def setup_qat_model( ...@@ -346,12 +353,16 @@ def setup_qat_model(
model_fp32, model_fp32,
enable_fake_quant: bool = False, enable_fake_quant: bool = False,
enable_observer: bool = False, enable_observer: bool = False,
enable_learnable_observer: bool = False,
): ):
assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in ["default", "learnable"]
if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"): if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"):
raise RuntimeError("The model is already setup to be QAT, cannot setup again!") raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
device = model_fp32.device device = model_fp32.device
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
qat_method = cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
# prepare for qat may modify the fp32 model directly so we create a copy # prepare for qat may modify the fp32 model directly so we create a copy
model_fp32_state_dict = model_fp32.state_dict() model_fp32_state_dict = model_fp32.state_dict()
...@@ -359,21 +370,32 @@ def setup_qat_model( ...@@ -359,21 +370,32 @@ def setup_qat_model(
# prepare model for qat # prepare model for qat
model = _prepare_model_for_qat(cfg, model_fp32) model = _prepare_model_for_qat(cfg, model_fp32)
# make sure the proper qconfig are used in the model
qat_utils.check_for_learnable_fake_quant_ops(qat_method, model)
# Move newly added observers to the original device # Move newly added observers to the original device
model.to(device) model.to(device)
if not enable_fake_quant: if not enable_fake_quant:
logger.info("Disabling fake quant ...") logger.info("Disabling fake quant ...")
model.apply(torch.ao.quantization.disable_fake_quant) model.apply(torch.ao.quantization.disable_fake_quant)
model.apply(qat_utils.disable_lqat_fake_quant)
if not enable_observer: if not enable_observer:
logger.info("Disabling observer ...") logger.info("Disabling static observer ...")
model.apply(torch.ao.quantization.disable_observer) model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer)
if not enable_learnable_observer and qat_method == "learnable":
logger.info("Disabling learnable observer ...")
model.apply(qat_utils.disable_lqat_learnable_observer)
# qat state dict mapper # qat state dict mapper
model = _setup_non_qat_to_qat_state_dict_map( model = _setup_non_qat_to_qat_state_dict_map(
model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE model_fp32_state_dict, model, is_eager_mode=cfg.QUANTIZATION.EAGER_MODE
) )
# qat optimizer group for learnable qat
model = qat_utils.setup_qat_get_optimizer_param_groups(model, qat_method)
return model return model
...@@ -433,6 +455,7 @@ class QATHook(HookBase): ...@@ -433,6 +455,7 @@ class QATHook(HookBase):
self._applied = { self._applied = {
"enable_fake_quant": False, "enable_fake_quant": False,
"enable_observer": False, "enable_observer": False,
"enable_learnable_observer": False,
"disable_observer": False, "disable_observer": False,
"freeze_bn_stats": False, "freeze_bn_stats": False,
} }
...@@ -455,6 +478,7 @@ class QATHook(HookBase): ...@@ -455,6 +478,7 @@ class QATHook(HookBase):
"[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter) "[QAT] enable fake quant to start QAT, iter = {}".format(cur_iter)
) )
model.apply(torch.ao.quantization.enable_fake_quant) model.apply(torch.ao.quantization.enable_fake_quant)
model.apply(qat_utils.enable_lqat_fake_quant)
self._applied["enable_fake_quant"] = True self._applied["enable_fake_quant"] = True
_reset_qat_data_loader_if_needed( _reset_qat_data_loader_if_needed(
...@@ -466,10 +490,19 @@ class QATHook(HookBase): ...@@ -466,10 +490,19 @@ class QATHook(HookBase):
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
and cur_iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER and cur_iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
): ):
logger.info("[QAT] enable observer, iter = {}".format(cur_iter)) logger.info("[QAT] enable static observer, iter = {}".format(cur_iter))
model.apply(torch.ao.quantization.enable_observer) model.apply(torch.ao.quantization.enable_observer)
model.apply(qat_utils.enable_lqat_static_observer)
self._applied["enable_observer"] = True self._applied["enable_observer"] = True
if (
not self._applied["enable_learnable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER
):
logger.info(f"[QAT] enabling learnable observer, iter = {cur_iter}")
model.apply(qat_utils.enable_lqat_learnable_observer)
self._applied["enable_learnable_observer"] = True
if ( if (
not self._applied["disable_observer"] not self._applied["disable_observer"]
and cur_iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER and cur_iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
...@@ -478,6 +511,8 @@ class QATHook(HookBase): ...@@ -478,6 +511,8 @@ class QATHook(HookBase):
"[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter) "[QAT] disabling observer for sub seq iters, iter = {}".format(cur_iter)
) )
model.apply(torch.ao.quantization.disable_observer) model.apply(torch.ao.quantization.disable_observer)
model.apply(qat_utils.disable_lqat_static_observer)
model.apply(qat_utils.disable_lqat_learnable_observer)
self._applied["disable_observer"] = True self._applied["disable_observer"] = True
if ( if (
......
...@@ -6,11 +6,13 @@ from collections import defaultdict ...@@ -6,11 +6,13 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
from d2go.utils.qat_utils import iterate_module_named_parameters
from detectron2.solver.build import ( from detectron2.solver.build import (
maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping, maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
) )
from detectron2.utils.registry import Registry from detectron2.utils.registry import Registry
D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER") D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -101,23 +103,6 @@ def regroup_optimizer_param_groups(params: List[Dict[str, Any]]): ...@@ -101,23 +103,6 @@ def regroup_optimizer_param_groups(params: List[Dict[str, Any]]):
return ret return ret
def iterate_module_named_parameters(
model: OptimizerModelsType, check_requires_grad=True
):
"""Iterate over all parameters for the model"""
memo = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if check_requires_grad and not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
yield module_name, module, module_param_name, value
def get_optimizer_param_groups_default(model: OptimizerModelsType): def get_optimizer_param_groups_default(model: OptimizerModelsType):
ret = [ ret = [
{ {
......
#!/usr/bin/env python3
import logging
from functools import partial
import torch
import torch.distributed as dist
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
logger = logging.getLogger(__name__)
def mixin_with_subclass(module, mix_class):
    """Mix `mix_class` into the concrete type of `module`.

    Dynamically creates a subclass of both `mix_class` and `type(module)`
    and returns an instance of it that shares all the data of `module`
    (the instance `__dict__` is copied over instead of re-running the
    base class `__init__`).
    """
    base_type = type(module)

    class SubClass(mix_class, base_type):
        def __init__(self, src):
            assert isinstance(src, base_type)
            # Take over all state from `src` directly; the base __init__
            # is deliberately not invoked again.
            self.__dict__ = src.__dict__.copy()

    return SubClass(module)
def _has_module(model, module_type):
for x in model.modules():
if isinstance(x, module_type):
return True
return False
def check_for_learnable_fake_quant_ops(qat_method, model):
    """Make sure learnable observers are used if qat method is `learnable`.

    Args:
        qat_method: the configured QAT method ("default" or "learnable").
        model: the quantization-prepared model to inspect.

    Raises:
        Exception: if `qat_method` is "learnable" but the model contains no
            `_LearnableFakeQuantize` modules (i.e. the qconfig was not built
            with `get_qat_qconfig()`).
    """
    if qat_method == "learnable":
        if not _has_module(model, _LearnableFakeQuantize):
            # Fixed typo in the error message ("quantzation" -> "quantization").
            raise Exception(
                "No learnable fake quant is used for learnable quantization, "
                "please use d2go.utils.qat_utils.get_qat_qconfig() to get proper qconfig"
            )
def iterate_module_named_parameters(model, check_requires_grad=True):
    """Yield (module_name, module, param_name, param) for each unique parameter.

    Walks every submodule of `model` and its directly-owned parameters
    (recurse=False per module). Parameters shared by multiple modules are
    yielded once. When `check_requires_grad` is True, frozen parameters
    (requires_grad == False) are skipped.
    """
    seen = set()
    for mod_name, mod in model.named_modules():
        for param_name, param in mod.named_parameters(recurse=False):
            if check_requires_grad and not param.requires_grad:
                continue
            if param in seen:
                # Skip parameters already yielded via another module.
                continue
            seen.add(param)
            yield mod_name, mod, param_name, param
def get_qat_qconfig(backend, qat_method="default"):
    """Build the QConfig used for QAT training.

    Args:
        backend: quantized backend, "qnnpack" or "fbgemm".
        qat_method: "default" returns the stock PyTorch QAT qconfig;
            "learnable" returns a qconfig whose fake-quant modules are
            `_LearnableFakeQuantize`, so scale/zero_point become trainable.

    Returns:
        torch.quantization.QConfig
    """
    assert backend in ["qnnpack", "fbgemm"]
    assert qat_method in ["default", "learnable"]

    if qat_method == "default":
        return torch.quantization.get_default_qat_qconfig(backend)

    # Per-backend activation settings, mirroring `get_default_qat_qconfig()`
    # (fbcode/caffe2/torch/quantization/qconfig.py).
    act_kwargs = {
        "fbgemm": {
            "reduce_range": True,
        },
        "qnnpack": {
            "reduce_range": False,
        },
    }[backend]

    # Per-backend weight settings, mirroring
    # `default_per_channel_weight_fake_quant` (fbgemm) and
    # `default_weight_fake_quant` (qnnpack)
    # (fbcode/caffe2/torch/quantization/fake_quantize.py).
    weight_kwargs = {
        "fbgemm": {
            "observer": torch.quantization.MovingAveragePerChannelMinMaxObserver,
            "qscheme": torch.per_channel_symmetric,
            "reduce_range": False,
            "ch_axis": 0,
        },
        "qnnpack": {
            "observer": torch.quantization.MovingAverageMinMaxObserver,
            "qscheme": torch.per_tensor_symmetric,
            "reduce_range": False,
        },
    }[backend]

    activation = _LearnableFakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        use_grad_scaling=True,
        **act_kwargs,
    )
    weight = _LearnableFakeQuantize.with_args(
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        use_grad_scaling=True,
        **weight_kwargs,
    )
    return torch.quantization.QConfig(activation=activation, weight=weight)
def get_world_size() -> int:
    """Return the distributed world size, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def sync_tensor(data):
    """Average `data` in-place across all distributed workers.

    No-op for a single process; otherwise all-reduces (SUM) and divides
    by the world size so every worker ends with the mean value.
    """
    workers = get_world_size()
    if workers > 1:
        dist.all_reduce(data, op=dist.ReduceOp.SUM)
        data /= workers
def toggle_lqat_fake_quant(mod, enable):
    """Toggle fake quantization on a learnable-qat fake-quant module.

    Modules that are not exactly `_LearnableFakeQuantize` are left untouched.
    """
    if type(mod) is not _LearnableFakeQuantize:
        return
    mod.toggle_fake_quant(enable)


# enable/disable fake quantization for learnable qat
enable_lqat_fake_quant = partial(toggle_lqat_fake_quant, enable=True)
disable_lqat_fake_quant = partial(toggle_lqat_fake_quant, enable=False)
def toggle_lqat_static_observer(mod, enable):
    """Toggle static observer updates on a learnable-qat fake-quant module.

    Modules that are not exactly `_LearnableFakeQuantize` are left untouched.
    """
    if type(mod) is not _LearnableFakeQuantize:
        return
    mod.toggle_observer_update(enable)


# enable/disable static observer for learnable qat
enable_lqat_static_observer = partial(toggle_lqat_static_observer, enable=True)
disable_lqat_static_observer = partial(toggle_lqat_static_observer, enable=False)
def enable_lqat_learnable_observer(mod):
    """Switch a learnable fake-quant module to learned qparams.

    Synchronizes scale/zero_point across workers first so all replicas
    start learning from the same values, then enables qparam learning
    and disables static observer updates.
    """
    if type(mod) is not _LearnableFakeQuantize:
        return
    sync_tensor(mod.scale.data)
    sync_tensor(mod.zero_point.data)
    mod.toggle_qparam_learning(enabled=True).toggle_observer_update(enabled=False)
def disable_lqat_learnable_observer(mod):
    """Turn off qparam learning on a learnable fake-quant module (no-op otherwise)."""
    if type(mod) is not _LearnableFakeQuantize:
        return
    mod.toggle_qparam_learning(enabled=False)
def get_optimizer_param_groups_learnable_qat(model, _):
    """Build optimizer param groups that zero the weight decay for the
    scale/zero_point parameters of learnable fake-quant modules.

    The second argument (optimizer opts) is accepted for interface
    compatibility and unused.
    """
    groups = []
    iterator = iterate_module_named_parameters(model, check_requires_grad=False)
    for _mod_name, mod, param_name, param in iterator:
        if not isinstance(mod, _LearnableFakeQuantize):
            continue
        if param_name in ("scale", "zero_point"):
            # Quantization parameters must not be decayed toward zero.
            groups.append({"params": [param], "weight_decay": 0.0})
    return groups
def _is_observer_key(state_dict_key):
observer_keys = ["activation_post_process", "weight_fake_quant"]
return any(x in state_dict_key for x in observer_keys)
def _is_q_state_dict(state_dict):
return any(_is_observer_key(k) for k in state_dict)
class ModelGetOptimizerParamGroupLearnableQATMixin:
    """Mixin that appends learnable-qat (zero weight decay) param groups
    to whatever `get_optimizer_param_groups` the base model provides.
    """

    def get_optimizer_param_groups(self, opts):
        parent = super()
        groups = []
        if hasattr(parent, "get_optimizer_param_groups"):
            groups = parent.get_optimizer_param_groups(opts)
        groups += get_optimizer_param_groups_learnable_qat(self, opts)
        return groups
def setup_qat_get_optimizer_param_groups(model, qat_method):
    """Attach a `get_optimizer_param_groups` method to the model so that it
    reports zero weight decay for learnable-qat scale/zero_point parameters.

    For any method other than "learnable" the model is returned unchanged.
    The model must already be quantization-prepared (its state dict must
    contain observer keys).
    """
    if qat_method != "learnable":
        return model

    assert _is_q_state_dict(model.state_dict())
    mixed = mixin_with_subclass(model, ModelGetOptimizerParamGroupLearnableQATMixin)
    assert hasattr(mixed, "get_optimizer_param_groups")
    return mixed
...@@ -316,7 +316,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -316,7 +316,7 @@ class TestDefaultRunner(unittest.TestCase):
return losses return losses
def setup(tmp_dir, backend): def setup(tmp_dir, backend, qat_method):
ds_name = create_local_dataset(tmp_dir, 5, 10, 10) ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
runner = default_runner.Detectron2GoRunner() runner = default_runner.Detectron2GoRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name) cfg = _get_cfg(runner, tmp_dir, ds_name)
...@@ -324,19 +324,29 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -324,19 +324,29 @@ class TestDefaultRunner(unittest.TestCase):
( (
["MODEL.META_ARCHITECTURE", "MetaArchForTestQAT1"] ["MODEL.META_ARCHITECTURE", "MetaArchForTestQAT1"]
+ ["QUANTIZATION.QAT.ENABLED", "True"] + ["QUANTIZATION.QAT.ENABLED", "True"]
+ ["QUANTIZATION.QAT.START_ITER", "0"] + ["QUANTIZATION.QAT.START_ITER", "1"]
+ ["QUANTIZATION.QAT.ENABLE_OBSERVER_ITER", "0"] + ["QUANTIZATION.QAT.ENABLE_OBSERVER_ITER", "0"]
+ ["QUANTIZATION.QAT.ENABLE_LEARNABLE_OBSERVER_ITER", "2"]
+ ["QUANTIZATION.QAT.DISABLE_OBSERVER_ITER", "4"]
+ ["QUANTIZATION.QAT.FREEZE_BN_ITER", "4"]
+ ["QUANTIZATION.BACKEND", backend] + ["QUANTIZATION.BACKEND", backend]
+ ["QUANTIZATION.QAT.FAKE_QUANT_METHOD", qat_method]
) )
) )
return runner, cfg return runner, cfg
for backend in ["fbgemm", "qnnpack"]: # seems that fbgemm with learnable qat is not supported
with tempfile.TemporaryDirectory() as tmp_dir: for backend, qat_method in [
runner, cfg = setup(tmp_dir, backend=backend) ("fbgemm", "default"),
model = runner.build_model(cfg) ("qnnpack", "default"),
print(model) ("qnnpack", "learnable"),
runner.do_train(cfg, model, resume=True) ]:
with self.subTest(backend=backend, qat_method=qat_method):
with tempfile.TemporaryDirectory() as tmp_dir:
runner, cfg = setup(tmp_dir, backend=backend, qat_method=qat_method)
model = runner.build_model(cfg)
print(model)
runner.do_train(cfg, model, resume=True)
default_runner._close_all_tbx_writers() default_runner._close_all_tbx_writers()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment