"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "a8b744b547fbea939e73d1054cc3a3b5d95e93c6"
Commit aeb24a92 authored by Kai Zhang, committed by Facebook GitHub Bot

Delegate to model's customization

Summary: Delegate the FX quantization callback's customization to the model: when a module (or the Lightning task wrapping it) defines prepare_for_quant / prepare_for_quant_convert, the QAT callback uses those hooks instead of its default prepare_qat_fx / convert_fx path.

Reviewed By: wat3rBro

Differential Revision: D27669212

fbshipit-source-id: 2715546cf03134896da6f95ecddaf8503ff95d0b
parent 845d0b2c
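A minimal sketch of the delegation pattern introduced below, assuming an illustrative ToyModel and a default fbgemm qconfig (neither is part of this diff); the callback side simply checks for the hook with hasattr and falls back to plain FX preparation otherwise:

import torch
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx

class ToyModel(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.head(x)

    def prepare_for_quant(self, cfg=None):
        # model-owned customization: prepare only the submodule we want quantized
        qconfig = {"": torch.quantization.get_default_qat_qconfig("fbgemm")}
        self.head = prepare_qat_fx(self.head, qconfig)
        return self

    def prepare_for_quant_convert(self, cfg=None):
        # model-owned conversion of the prepared submodule
        self.head = convert_fx(self.head)
        return self

def prepare(model):
    # callback side: delegate when the hook exists, otherwise fall back to generic FX prepare
    if hasattr(model, "prepare_for_quant"):
        return model.prepare_for_quant()
    return prepare_qat_fx(model, {"": torch.quantization.get_default_qat_qconfig("fbgemm")})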
@@ -7,6 +7,7 @@ from functools import lru_cache
import torch
from d2go.export.api import PredictorExportConfig
from d2go.utils.prepare_for_export import d2_meta_arch_prepare_for_export
from detectron2.export.caffe2_modeling import (
META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
convert_batched_inputs_to_c2_format,
@@ -21,7 +22,7 @@ from mobile_cv.arch.utils.quantize_utils import (
QuantWrapper,
)
from mobile_cv.predictor.api import FuncInfo
from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
logger = logging.getLogger(__name__)
@@ -41,6 +42,16 @@ def patch_d2_meta_arch():
else:
cls_obj.prepare_for_quant = d2_meta_arch_prepare_for_quant
if hasattr(cls_obj, "prepare_for_quant_convert"):
assert (
cls_obj.prepare_for_quant_convert
== d2_meta_arch_prepare_for_quant_convert
)
else:
cls_obj.prepare_for_quant_convert = (
d2_meta_arch_prepare_for_quant_convert
)
def _apply_eager_mode_quant(cfg, model):
@@ -107,15 +118,37 @@ def _apply_eager_mode_quant(cfg, model):
return model
def _fx_quant_prepare(self, cfg):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
self.backbone = prep_fn(
self.backbone,
qconfig,
{"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv, qconfig
)
self.roi_heads.box_head.avgpool = prep_fn(self.roi_heads.box_head.avgpool, qconfig)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score, qconfig
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred, qconfig
)
def d2_meta_arch_prepare_for_quant(self, cfg):
model = self
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
model.qconfig = (
torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
@@ -124,4 +157,41 @@ def d2_meta_arch_prepare_for_quant(self, cfg):
)
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(model, inplace=True)
else:
_fx_quant_prepare(model, cfg)
return model
def d2_meta_arch_prepare_for_quant_convert(self, cfg):
if cfg.QUANTIZATION.EAGER_MODE:
raise NotImplementedError()
self.backbone = convert_fx(
self.backbone,
convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = convert_fx(
self.proposal_generator.rpn_head.rpn_feature
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
)
self.roi_heads.box_head.roi_box_conv = convert_fx(
self.roi_heads.box_head.roi_box_conv
)
self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
self.roi_heads.box_predictor.cls_score = convert_fx(
self.roi_heads.box_predictor.cls_score
)
self.roi_heads.box_predictor.bbox_pred = convert_fx(
self.roi_heads.box_predictor.bbox_pred
)
return self
@@ -270,6 +270,9 @@ def default_prepare_for_quant(cfg, model):
return model
def default_prepare_for_quant_convert(cfg, model):
return torch.quantization.quantize_fx.convert_fx(model)
@mock_quantization_type
def post_training_quantize(cfg, model, data_loader):
...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.utils.registry import Registry
from d2go.config import CfgNode
CALLBACK_REGISTRY = Registry("D2GO_CALLBACK_REGISTRY")
def build_quantization_callback(cfg: CfgNode):
return CALLBACK_REGISTRY.get(cfg.QUANTIZATION.NAME).from_config(cfg)
@@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union
import torch
from d2go.config import CfgNode
from d2go.utils.misc import mode
from d2go.runner.callbacks.build import CALLBACK_REGISTRY
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
@@ -26,6 +25,7 @@ from torch.quantization.utils import get_quant_type
QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
PREPARED = "_prepared"
def rsetattr(obj: Any, attr: str, val: Any) -> None:
@@ -96,6 +96,20 @@ def _requires_calibration(config_dicts: QConfigDicts) -> bool:
return False
def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
return any(k.startswith(PREPARED) for k in checkpoint["state_dict"].keys())
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint
setattr(
model,
PREPARED,
_deepcopy(model).prepare_for_quant()
)
class QuantizationMixin(ABC):
"""Mixin defining an overrideable API for quantization customization.
@@ -155,6 +169,8 @@
Returns:
The prepared Module to be used for quantized aware training.
"""
if hasattr(root, "prepare_for_quant"):
return root.prepare_for_quant()
prep_fn = (
prepare_qat_fx
if isinstance(self, QuantizationAwareTraining)
@@ -191,6 +207,8 @@
Returns:
The quantized model.
"""
if hasattr(root, "prepare_for_quant_convert"):
return root.prepare_for_quant_convert()
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
@@ -238,7 +256,6 @@ class ModelTransform:
raise ValueError("interval must be positive.")
@CALLBACK_REGISTRY.register()
class QuantizationAwareTraining(Callback, QuantizationMixin):
"""Enable QAT of a model using the STL Trainer.
@@ -305,6 +322,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
] = None,
preserved_attrs: Optional[List[str]] = None,
skip_conversion: bool = False,
) -> None:
"""
Args:
@@ -399,6 +417,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
for key, value in qconfig_dicts.items()
}
self.quantized: Optional[torch.nn.Module] = None
self.skip_conversion = skip_conversion
@classmethod
def from_config(cls, cfg: CfgNode):
@@ -412,6 +431,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
start_step=qat.START_ITER,
enable_observer=(qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER),
freeze_bn_step=qat.FREEZE_BN_ITER,
skip_conversion=True, # convert_fx will be handled by D2Go exporter
)
if qat.UPDATE_OBSERVER_STATS_PERIODICALLY:
callback.transforms.append(
@@ -477,7 +497,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Quantize the weights since training has finalized. """
if hasattr(pl_module, "_quantized") or self.skip_conversion:
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
@@ -497,7 +517,6 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
self.quantized = pl_module._quantized
@CALLBACK_REGISTRY.register()
class PostTrainingQuantization(Callback, QuantizationMixin):
"""Enable post-training quantization, such as dynamic, static, and weight-only.
...
@@ -24,12 +24,14 @@ from d2go.runner.default_runner import (
)
from d2go.setup import setup_after_launch
from d2go.utils.ema_state import EMAState
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from detectron2.modeling import build_model
from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
build_optimizer as d2_build_optimizer,
)
from pytorch_lightning.utilities import rank_zero_info
from d2go.modeling.quantization import default_prepare_for_quant, default_prepare_for_quant_convert
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
@@ -157,7 +159,8 @@ class DefaultTask(pl.LightningModule):
@classmethod
def build_model(cls, cfg: CfgNode, eval_only=False):
"""Builds D2go model instance from config. If model has been prepared
for quantization, the function returns the prepared model.
NOTE: For backward compatible with existing D2Go tools. Prefer
`from_config` in other use cases.
@@ -165,7 +168,10 @@
cfg: D2go config node.
eval_only: True if model should be in eval mode.
"""
task = cls.from_config(cfg, eval_only)
if hasattr(task, PREPARED):
task = getattr(task, PREPARED)
return task.model
@classmethod
def get_default_cfg(cls):
@@ -362,6 +368,8 @@
if not _is_lightning_checkpoint(checkpointed_state):
_convert_to_lightning(checkpointed_state)
maybe_prepare_for_quantization(self, checkpointed_state)
if self.ema_state:
if "model_ema" not in checkpointed_state:
rank_zero_info(
@@ -374,6 +382,20 @@
# EMA state device not given, move to module device
self.ema_state.to(self.device)
def prepare_for_quant(self) -> pl.LightningModule:
if hasattr(self.model, "prepare_for_quant"):
self.model = self.model.prepare_for_quant(self.cfg)
else:
self.model = default_prepare_for_quant(self.cfg, self.model)
return self
def prepare_for_quant_convert(self) -> pl.LightningModule:
if hasattr(self.model, "prepare_for_quant_convert"):
self.model = self.model.prepare_for_quant_convert(self.cfg)
else:
self.model = default_prepare_for_quant_convert(self.cfg, self.model)
return self
class GeneralizedRCNNTask(DefaultTask):
@classmethod
...
@@ -6,6 +6,7 @@ import torch
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
@META_ARCH_REGISTRY.register()
@@ -50,6 +51,17 @@ class DetMetaArchForTest(torch.nn.Module):
ret = [{"instances": instance}]
return ret
def prepare_for_quant(self, cfg):
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": torch.quantization.get_default_qat_qconfig()},
)
return self
def prepare_for_quant_convert(self, cfg):
self.avgpool = convert_fx(self.avgpool)
return self
def get_det_meta_arch_cfg(cfg, dataset_name, output_dir):
cfg.MODEL.DEVICE = "cpu"
...
@@ -11,12 +11,18 @@ import pytorch_lightning as pl  # type: ignore
import torch
from d2go.config import CfgNode, temp_defrost
from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import (
QuantizationAwareTraining,
)
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
class TestLightningTask(unittest.TestCase):
@@ -175,3 +181,80 @@ class TestLightningTask(unittest.TestCase):
model.state_dict(), task.ema_state.state_dict()
)
)
@tempdir
def test_qat(self, tmp_dir):
@META_ARCH_REGISTRY.register()
class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
custom_config_dict = {"preserved_attributes": ["preserved_attr"]}
def __init__(self, cfg):
super().__init__(cfg)
self.avgpool.preserved_attr = "foo"
self.avgpool.not_preserved_attr = "bar"
def prepare_for_quant(self, cfg):
self.avgpool = prepare_qat_fx(
self.avgpool,
{"": torch.quantization.get_default_qat_qconfig()},
self.custom_config_dict,
)
return self
def prepare_for_quant_convert(self, cfg):
self.avgpool = convert_fx(
self.avgpool, convert_custom_config_dict=self.custom_config_dict
)
return self
cfg = self._get_cfg(tmp_dir)
cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
cfg.QUANTIZATION.QAT.ENABLED = True
task = GeneralizedRCNNTask(cfg)
callbacks = [
QuantizationAwareTraining.from_config(cfg),
ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
]
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=callbacks,
logger=None,
)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
prepared_avgpool = task._prepared.model.avgpool
self.assertEqual(prepared_avgpool.preserved_attr, "foo")
self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))
with temp_defrost(cfg):
cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))
@tempdir
def test_generalized_rcnn_qat(self, tmp_dir):
cfg = GeneralizedRCNNTask.get_default_cfg()
cfg.merge_from_file("detectron2go://e2e_mask_rcnn_fbnet_600_qat.yaml")
cfg.MODEL.DEVICE = "cpu"
cfg.QUANTIZATION.EAGER_MODE = False
cfg.OUTPUT_DIR = tmp_dir
task = GeneralizedRCNNTask(cfg)
callbacks = [
QuantizationAwareTraining.from_config(cfg),
ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
]
trainer = pl.Trainer(
max_steps=1,
limit_train_batches=1,
num_sanity_val_steps=0,
callbacks=callbacks,
logger=None,
)
with EventStorage() as storage:
task.storage = storage
trainer.fit(task)
@@ -4,7 +4,6 @@
import logging
import os
from d2go.runner.callbacks.build import build_quantization_callback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type
@@ -67,8 +66,8 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
save_last=True,
),
]
if cfg.QUANTIZATION.QAT.ENABLED:
callbacks.append(QuantizationAwareTraining.from_config(cfg))
return callbacks
@@ -177,11 +176,6 @@ def main(
else:
model_configs = do_train(cfg, trainer, task)
for cb in trainer_params["callbacks"]:
if isinstance(cb, QuantizationAwareTraining):
print("################ quantized #################")
print(cb.quantized)
return TrainOutput(
output_dir=cfg.OUTPUT_DIR,
tensorboard_log_dir=tb_logger.log_dir,
...
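For context on how the pieces above fit together at load time: after QAT the Lightning checkpoint contains _prepared.* keys, maybe_prepare_for_quantization re-creates the prepared copy when such a checkpoint is loaded, and build_model returns the model from that copy (with skip_conversion=True, convert_fx is left to the D2Go exporter). A rough sketch mirroring the new test; the helper name and checkpoint path are illustrative, not part of this diff:

import os
from d2go.config import temp_defrost
from d2go.runner.lightning_task import GeneralizedRCNNTask

def load_prepared_model(cfg, output_dir):  # hypothetical helper, for illustration only
    with temp_defrost(cfg):
        # point MODEL.WEIGHTS at the QAT checkpoint written by ModelCheckpoint(save_last=True)
        cfg.MODEL.WEIGHTS = os.path.join(output_dir, "last.ckpt")
        # build_model notices the prepared checkpoint, rebuilds the prepared copy of the task,
        # and returns its model; conversion to a quantized model is handled later by the exporter
        return GeneralizedRCNNTask.build_model(cfg, eval_only=True)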