Commit 5d8068d8 authored by Kai Zhang, committed by Facebook GitHub Bot

Copy quantization callback to D2go

Summary: As titled. Make a copy of quantization callback to unblock D2go OSS.

Reviewed By: zhanghang1989

Differential Revision: D26735525

fbshipit-source-id: 12b77f04cfa1361e856b26ea218a262da1fadd88
parent f23248c0
# pyre-ignore-all-errors
import functools
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass
from types import MethodType
from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union
import torch
from d2go.utils.misc import mode
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info
from torch.quantization import ( # @manual
QConfig,
QConfigDynamic,
QuantType,
get_default_qat_qconfig,
get_default_qconfig,
)
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from torch.quantization.utils import get_quant_type
QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
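# Shape sketch (the submodule name below is illustrative): outer keys are fully
# qualified submodule names ("" means the whole model), inner dicts are the
# usual FX-mode qconfig dicts, e.g.:
#   {"backbone": {"": get_default_qat_qconfig()}}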
def rsetattr(obj: Any, attr: str, val: Any) -> None:
""" Same as setattr but supports deeply nested objects. """
pre, _, post = attr.rpartition(".")
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj: Any, attr: str, *args) -> Any:
""" Same as getattr but supports deeply nested objects. """
def _getattr(obj, attr):
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split("."))
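# Illustrative sketch (attribute names are hypothetical): for a module with a
# nested submodule at model.backbone.linear,
#   rgetattr(model, "backbone.linear")             # returns the nested layer
#   rsetattr(model, "backbone.linear", new_layer)  # swaps it in place
# This is how prepare()/convert() below rewrite prepared submodules by their
# fully qualified names.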
def _deepcopy(pl_module: LightningModule) -> LightningModule:
"""Copy a LightningModule. Some properties need to be ignored. """
# Remove _results before the call to deepcopy since it stores non-leaf Tensors.
# If not removed, you'll see this error on deepcopy() attempts: P150283141.
if hasattr(pl_module, "_results"):
result = pl_module._results
delattr(pl_module, "_results")
copy = deepcopy(pl_module)
# Set back.
pl_module._results = result
else:
copy = deepcopy(pl_module)
return copy
def _quantized_forward(self, *args, **kwargs):
""" Forward method for a quantized module. """
if not self.training and hasattr(self, "_quantized"):
return self._quantized(*args, **kwargs)
return self._prepared(*args, **kwargs)
def _requires_calibration(config_dicts: QConfigDicts) -> bool:
"""Returns whether the given config_dicts for quantization requires calibration.
A config_dicts requires calibration if at least one of the configs in the
dictionary is a QConfig with an activation observer.
Args:
config_dicts: The config dictionaries to check.
Returns:
Boolean as described.
"""
for qconfig_dict in config_dicts.values():
for qconfig in qconfig_dict.values():
qtype = get_quant_type(qconfig)
if qtype == QuantType.STATIC:
return True
return False
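# For example, with the standard torch.quantization qconfigs: a mapping like
# {"": {"": get_default_qconfig()}} uses static activation observers and so
# requires calibration, while {"": {"": default_dynamic_qconfig}} is
# dynamic-only and does not.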
class QuantizationMixin(ABC):
"""Mixin defining an overrideable API for quantization customization.
For example, suppose our model contains traceable and non-traceable modules:
>>> class MyNonTraceableModel(LightningModule):
... def __init__(self):
... self.traceable = ...
... self.non_traceable = ...
...
... def forward(self, x):
... x = self.traceable(x)
... return self.non_traceable(x)
Then using FX-mode quantization, we can only quantize the traceable pieces.
As such, we could do something like the below, shown here for QAT.
>>> class MyQuantizationCallback(QuantizationAwareTraining):
... def prepare(self, model, config):
... model.traceable = prepare_qat_fx(model.traceable, config)
... return model
...
... def convert(self, model):
... model.traceable = convert_fx(model.traceable)
... return model
We can then use this callback as we would any other:
Example::
>>> model = MyNonTraceableModel(...)
>>> quantization = MyQuantizationCallback()
>>> trainer = Trainer(
... callbacks=[quantization],
... )
>>> trainer.fit(model)
"""
def prepare(self, root: LightningModule, configs: QConfigDicts) -> torch.nn.Module:
"""Prepares the root user modules for quantization.
By default, this tries to prepare the entire LightningModule. If this is
not possible (e.g., due to traceability issues), the recommended method to
use is to override the `prepare` method to prepare the root as
appropriate, and also override the `quantize` method to only quantize
the prepared pieces of the root.
Args:
root: The LightningModule as given to the lightning Trainer in train mode.
configs: Specification to be used when preparing the model, as provided by the user.
It is guaranteed that no key is a suffix of another.
Returns:
The prepared Module to be used for quantization-aware training.
"""
prep_fn = (
prepare_qat_fx
if isinstance(self, QuantizationAwareTraining)
else prepare_fx
)
if "" in configs:
return prep_fn(root, configs[""])
for name, config in configs.items():
submodule = rgetattr(root, name)
rsetattr(root, name, prep_fn(submodule, config))
return root
def convert(self, root: torch.nn.Module, submodules: Set[str]) -> torch.nn.Module:
"""Quantizes a previously prepared module (as returned by `prepare`).
By default, this simply quantizes the entire root. If the `prepare`
method was customized, this will need to be changed as well.
Args:
root: The prepared model as returned by `prepare`, after training.
submodules: An iterable of fully qualified submodule names that require
converting.
Returns:
The quantized model.
"""
if "" in submodules:
return convert_fx(root)
for name in submodules:
prepared = rgetattr(root, name)
rsetattr(root, name, convert_fx(prepared))
return root
@dataclass(frozen=True)
class ModelTransform:
"""Defines a step or interval at which fn should be .apply(fn)'ed and a message to log.
Properties:
fn: The function to apply. Must be passable to torch.nn.Module.apply(fn).
step: Only one of `step` or `interval` must be defined. If step is defined,
`fn` will be applied exactly once right before `step` step begins.
interval: Only one of `step` or `interval` must be defined. If `interval`
is defined, the transform will be applied periodically every
`interval` steps.
message: A short non-punctuated message to log on the rank-zero worker when
this transform is triggered.
"""
fn: Callable[[torch.nn.Module], None]
message: str
step: Optional[int] = None
interval: Optional[int] = None
def __post_init__(self):
""" Validate a few properties for early failure. """
if (self.step is None and self.interval is None) or (
self.step is not None and self.interval is not None
):
raise TypeError("Exactly one of step or interval must be defined.")
if self.step is not None and self.step < 0:
raise ValueError("step must be non-negative.")
if self.interval is not None and self.interval <= 0:
raise ValueError("interval must be positive.")
class QuantizationAwareTraining(Callback, QuantizationMixin):
"""Enable QAT of a model using the STL Trainer.
Node that this callback makes changes during training in order to properly
quantize the provided LightningModule.
Example::
>>> from d2go.runner.callbacks.quantization import QuantizationAwareTraining
>>> from pytorch_lightning import Trainer
>>> from d2go.utils.misc import mode
...
# MyLightningModule must define val_dataloader() which is used both for
# validation as well as calibration of the quantized model.
>>> model = MyLightningModule(...)
>>> qat = QuantizationAwareTraining()
>>> trainer = Trainer(
... callbacks=[qat],
... )
# This will convert the model into one that is quantizable, train it,
# and then quantize it after training is done.
>>> trainer.fit(model)
# You can use the model directly.
>>> input = ...
>>> with mode(model, training=False) as m:
... quantized_out = m(input)
If you only wish to quantize parts of your model, please see QuantizationMixin
for an example of how to do this.
Properties:
transforms: A list of ModelTransforms applied to the model during training
as specified. Example transforms are enabling/disabling
observers/fake-quantization modules, which are added to this list based on the init
parameters to this callback. Users can further augment the list
with custom transforms.
prepared: If set, this is the prepared model. Only available
after .fit() starts.
qconfig_dicts:
This is a map from the `module_qualified_name` to the corresponding QConfigDict
to apply to that module. For example, suppose your LightningModule contains
two submodules module.scriptable and module.not_scriptable. You'd provide
a qconfig_dicts like:
{
"scriptable": ...
}
This will quantize just module.scriptable using the provided QConfigDict,
or a default one. If you wish to quantize the entire LightningModule,
simply use "" as the qualified name. The name should match the names
returned by module.named_modules().
quantized: If set, this is the fully quantized model. Only available
after .fit() finishes.
"""
def __init__(
self,
start_step: int = 0,
enable_observer: Tuple[int, Optional[int]] = (0, None),
freeze_bn_step: Optional[int] = None,
qconfig_dicts: Optional[
Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
] = None,
) -> None:
"""
Args:
start_step: The training step at which QAT is enabled. The model is
always mutated with the appropriate stubs, but they are disabled
until the start of this training step.
See FakeQuantizeBase.fake_quant_enabled
enable_observer: The half-open interval [a, b) in steps during which the
observers are enabled. See FakeQuantizeBase.observer_enabled. If
b is None, the observer is never disabled once enabled.
freeze_bn_step: If specified, the step at which we freeze the
collection of batch normalization layer statistics for QAT.
qconfig_dicts: If given, used for quantization of the model during training.
"""
if start_step < 0:
raise ValueError(
f"The starting step of QAT must be non-negative. Got {start_step}."
)
start_observer, end_observer = enable_observer
if start_observer < 0:
raise ValueError(
f"The starting step for the observer must be non-negative. Got {start_observer}."
)
if end_observer and end_observer <= start_observer:
raise ValueError(
f"The observation interval must contain at least one step. Got [{start_step}, {end_observer})."
)
if freeze_bn_step and freeze_bn_step < 0:
raise ValueError(
f"The step at which batch norm layers are frozen must be non-negative. Got {freeze_bn_step}."
)
self.transforms: List[ModelTransform] = []
if start_step > 0:
self.transforms.extend(
[
# Enabled by default, so the assumption for > 0 is that the
# user wants it disabled then enabled.
ModelTransform(
fn=torch.quantization.disable_fake_quant,
step=0,
message="Disable fake quant",
),
ModelTransform(
fn=torch.quantization.enable_fake_quant,
step=start_step,
message="Enable fake quant to start QAT",
),
]
)
if start_observer > 0:
self.transforms.extend(
# See comment for start_step above.
[
ModelTransform(
fn=torch.quantization.disable_observer,
step=0,
message="Disable observer",
),
ModelTransform(
fn=torch.quantization.enable_observer,
step=start_observer,
message="Start observer",
),
]
)
if end_observer is not None:
self.transforms.append(
ModelTransform(
fn=torch.quantization.disable_observer,
step=end_observer,
message="End observer",
)
)
if freeze_bn_step is not None:
self.transforms.append(
ModelTransform(
fn=torch.nn.intrinsic.qat.freeze_bn_stats,
step=freeze_bn_step,
message="Freeze BN",
)
)
self.prepared: Optional[torch.nn.Module] = None
if not qconfig_dicts:
self.qconfig_dicts: QConfigDicts = {"": {"": get_default_qat_qconfig()}}
else:
self.qconfig_dicts: QConfigDicts = {
key: value if value else {"": get_default_qat_qconfig()}
for key, value in qconfig_dicts.items()
}
self.quantized: Optional[torch.nn.Module] = None
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Override the model with a quantized-aware version on setup.
This is the earliest place we can override this model which allows for
appropriate behavior when restoring from checkpoints, as well as connecting
to accelerators, etc.
The model is only prepared once.
"""
# Only prepare the model once.
if hasattr(pl_module, "_prepared"):
return
with mode(pl_module, training=True) as train:
pl_module._prepared = self.prepare(
_deepcopy(train), configs=self.qconfig_dicts
)
pl_module.forward = MethodType(_quantized_forward, pl_module)
self.prepared = pl_module._prepared
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
""" Applies model transforms at as specified during training. """
apply_only_once = []
current_step = trainer.global_step
for i, transform in enumerate(self.transforms):
if (transform.step is not None and transform.step <= current_step) or (
transform.interval is not None
and current_step % transform.interval == 0
):
self.prepared.apply(transform.fn)
rank_zero_info(
f"[QAT] {transform.message} at step={trainer.global_step}."
)
if transform.step is not None and transform.step <= current_step:
apply_only_once.append(i)
if apply_only_once:
self.transforms = [
transform
for i, transform in enumerate(self.transforms)
if i not in set(apply_only_once)
]
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Quantize the weights since training has finalized. """
if hasattr(pl_module, "_quantized"):
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys()
)
self.quantized = pl_module._quantized
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Make sure we have a quantized version.
This handles the edge case where a user does .test() without .fit() first.
"""
if hasattr(pl_module, "_quantized"):
return
pl_module._quantized = self.convert(
pl_module._prepared, self.qconfig_dicts.keys()
)
self.quantized = pl_module._quantized
class PostTrainingQuantization(Callback, QuantizationMixin):
"""Enable post-training quantization, such as dynamic, static, and weight-only.
This is an idempotent callback (to contrast with QuantizationAwareTraining).
If calibration is required, we will use the validation data set provided to
the Lightning Trainer, and this occurs on each validation run.
The quantized model is made available as a property of the callback.
Example::
>>> from d2go.runner.callbacks.quantization import PostTrainingQuantization
>>> from pytorch_lightning import Trainer
>>> from d2go.utils.misc import mode
...
# MyLightningModule must define val_dataloader() which is used both for
# validation as well as calibration of the quantized model.
>>> model = MyLightningModule(...)
>>> post_training_quant = PostTrainingQuantization()
>>> trainer = Trainer(
... callbacks=[post_training_quant],
... )
# This will both train the model + create a *separate* quantized version.
# The original model is left unchanged.
>>> trainer.fit(model)
# You can access the quantized version of the model directly.
>>> input = ...
>>> with mode(post_training_quant.quantized, training=False) as m:
... quantized_out = m(input)
If you only wish to quantize parts of your model, please see QuantizationMixin
for an example of how to do this.
Properties:
prepared: If set, this is the prepared model which can be used for
calibration. Only available after validation start.
qconfig_dicts: See `QuantizationAwareTraining` for the full description.
quantized: If set, this is the fully quantized model calibrated using
the validation data. Only available after validation has ended.
"""
def __init__(self, qconfig_dicts: Optional[QConfigDicts] = None) -> None:
""" Initialize the callback. """
self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}}
self.prepared: Optional[torch.nn.Module] = None
self.quantized: Optional[torch.nn.Module] = None
self.should_calibrate = _requires_calibration(self.qconfig_dicts)
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
On validation start, prepare a module for quantization by adding
observers and loading weights from current model.
"""
# Pass a copy to quantization APIs.
self.prepared = self.prepare(
_deepcopy(pl_module).eval(), configs=self.qconfig_dicts
)
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Convert the calibrated model to its finalized quantized version. """
self.quantized = self.convert(self.prepared, self.qconfig_dicts.keys())
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: Any,
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
""" Also run the validation batch through the quantized model for calibration. """
if self.should_calibrate:
with torch.no_grad():
self.prepared(batch)
@@ -4,10 +4,12 @@
import logging
import os
import warnings
from contextlib import contextmanager
from typing import Dict, Iterator
import detectron2.utils.comm as comm
import torch
from d2go.config import CfgNode
from fvcore.common.file_io import PathManager
from tabulate import tabulate
@@ -16,6 +18,7 @@ from .tensorboard_log_util import get_tensorboard_log_dir
logger = logging.getLogger(__name__)
def check_version(library, min_version, warning_only=False):
"""Check the version of the library satisfies the provided minimum version.
An exception is thrown if the check does not pass.
@@ -27,11 +30,14 @@ def check_version(library, min_version, warning_only=False):
Printing a warning instead of throwing an exception.
"""
from distutils.version import LooseVersion
version = library.__version__
bad_version = LooseVersion(version) < LooseVersion(min_version)
if bad_version:
msg = (
f"Installed {library.__name__} version {version} does not satisfy the "
f"minimum required version {min_version}"
)
if warning_only:
warnings.warn(msg)
else:
@@ -39,6 +45,7 @@ def check_version(library, min_version, warning_only=False):
return False
return True
def metrics_dict_to_metrics_table(dic):
assert isinstance(dic, dict)
ret = []
@@ -62,7 +69,9 @@ def print_metrics_table(metrics_dic):
logger.info("Metrics table: \n" + metrics_tabulate)
def dump_trained_model_configs(
output_dir: str, trained_cfgs: Dict[str, CfgNode]
) -> Dict[str, str]:
"""Writes trained model config files to output_dir.
Args:
@@ -82,3 +91,14 @@ def dump_trained_model_configs(output_dir: str, trained_cfgs: Dict[str, CfgNode]
with PathManager.open(config_file, "w") as f:
f.write(trained_cfg.dump())
return trained_model_configs
@contextmanager
def mode(net: torch.nn.Module, training: bool) -> Iterator[torch.nn.Module]:
"""Temporarily switch to training/evaluation mode."""
istrain = net.training
try:
net.train(training)
yield net
finally:
net.train(istrain)
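# Minimal usage sketch: run inference in eval mode and restore the previous
# training flag afterwards, even if the forward pass raises.
#   with mode(model, training=False) as m:
#       predictions = m(inputs)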
@@ -4,6 +4,7 @@
import os
from functools import wraps
from tempfile import TemporaryDirectory
import torch
import torch.distributed as dist
@@ -40,3 +41,13 @@ def enable_ddp_env(func):
return ret
return wrapper
def tempdir(func):
""" A decorator for creating a tempory directory that is cleaned up after function execution. """
@wraps(func)
def wrapper(self, *args, **kwargs):
with TemporaryDirectory() as temp:
return func(self, temp, *args, **kwargs)
return wrapper
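# Minimal usage sketch: the wrapped test method receives the temporary
# directory path as its first argument after self; the directory is removed
# once the method returns.
#   class MyTest(unittest.TestCase):
#       @tempdir
#       def test_writes_files(self, tmp_dir):
#           ...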
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from functools import wraps
from tempfile import TemporaryDirectory
from typing import Optional
import torch
from pytorch_lightning import LightningModule
from torch.utils.data.dataset import Dataset
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class TestModule(LightningModule):
def __init__(self, epoch_min_loss_override: Optional[int] = None):
"""LightningModule for testing purposes
Args:
epoch_min_loss_override (int, optional): If set, the (zero-based) epoch whose
validation loss is forced to be the minimum, for testing purposes. If None this is ignored. Defaults to None.
"""
super().__init__()
self.layer = torch.nn.Linear(in_features=32, out_features=2)
self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
self.epoch_min_loss_override = epoch_min_loss_override
def forward(self, x):
x = self.layer(x)
return self.another_layer(x)
def loss(self, batch, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
def training_step(self, batch, batch_idx):
output = self.forward(batch)
loss = self.loss(batch, output)
return {"output": output, "loss": loss, "checkpoint_on": loss}
def validation_step(self, batch, batch_idx):
output = self.forward(batch)
loss = self.loss(batch, output)
return {"output": output, "loss": loss, "checkpoint_on": loss}
def test_step(self, batch, batch_idx):
output = self.forward(batch)
loss = self.loss(batch, output)
return {"output": output, "loss": loss}
def training_epoch_end(self, outputs) -> None:
avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
self.log("avg_loss", avg_loss)
def validation_epoch_end(self, outputs) -> None:
avg_val_loss = torch.stack(
[torch.randn(1, requires_grad=True) for _ in outputs]
).mean()
# For testing purposes allow a nominated epoch to have a low loss
if self.current_epoch == self.epoch_min_loss_override:
avg_val_loss -= 1e10
self.log("val_loss", avg_val_loss)
self.log("checkpoint_on", avg_val_loss)
def test_epoch_end(self, outputs) -> None:
avg_loss = torch.stack(
[torch.randn(1, requires_grad=True) for _ in outputs]
).mean()
self.log("val_loss", avg_loss)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
#!/usr/bin/env python3
# pyre-unsafe
import os
import unittest
import mock
import torch
from d2go.runner.callbacks.quantization import (
PostTrainingQuantization,
QuantizationAwareTraining,
ModelTransform,
get_default_qconfig,
get_default_qat_qconfig,
)
from d2go.tests.helper import tempdir
from d2go.tests.lightning_test_module import TestModule
from d2go.utils.misc import mode
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from torch.quantization import (  # @manual
default_dynamic_qconfig,
get_default_qconfig,
)
class TestModelTransform(unittest.TestCase):
""" Tests ModelTransforms. """
def test_invalid_construction_type_error(self):
""" Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]. """
with self.assertRaises(TypeError):
_ = ModelTransform()
with self.assertRaises(TypeError):
_ = ModelTransform(fn=lambda x: x)
with self.assertRaises(TypeError):
_ = ModelTransform(message="No function defined")
with self.assertRaises(TypeError):
_ = ModelTransform(
fn=lambda x: x,
message="Specified both step and interval",
step=1,
interval=1,
)
def test_positivity_value_error(self):
""" Validates ModelTransforms are constructed with only valid arguments. """
def identity(x):
return x
with self.assertRaises(ValueError):
_ = ModelTransform(fn=identity, message="Negative step", step=-1)
with self.assertRaises(ValueError):
_ = ModelTransform(fn=identity, message="Zero interval", interval=0)
with self.assertRaises(ValueError):
_ = ModelTransform(fn=identity, message="Negative interval", interval=-1)
class TestQuantizationAwareTraining(unittest.TestCase):
def test_qat_misconfiguration(self):
""" Tests failure when misconfiguring the QAT Callback. """
invalid_params = [
{"start_step": -1},
{"enable_observer": (42, 42)},
{"enable_observer": (42, 21)},
{"enable_observer": (-1, None)},
{"freeze_bn_step": -1},
]
for invalid_param in invalid_params:
with self.assertRaises(ValueError):
_ = QuantizationAwareTraining(**invalid_param)
def test_qat_transforms(self):
""" Tests the appropropriate ModelTransforms are defined with QAT."""
qat = QuantizationAwareTraining(
start_step=300, enable_observer=(350, 500), freeze_bn_step=550
)
trainer = Trainer()
module = TestModule()
qat.setup(trainer, module, stage="train")
self.assertGreater(len(qat.transforms), 0)
def assertContainsTransformsAtStep(step):
"""
Asserts at least one transform exists at the specified step and
that it is removed after the step begins.
"""
self.assertGreater(
len(
[
transform
for transform in qat.transforms
if transform.step == step
]
),
0,
f"step={step}",
)
trainer.global_step = step
qat.on_train_batch_start(
trainer, module, batch=None, batch_idx=0, dataloader_idx=0
)
self.assertEqual(
len(
[
transform
for transform in qat.transforms
if transform.step == step
]
),
0,
f"step={step}",
)
assertContainsTransformsAtStep(step=300)
assertContainsTransformsAtStep(step=350)
assertContainsTransformsAtStep(step=500)
assertContainsTransformsAtStep(step=550)
@tempdir
def test_qat_interval_transform(self, root_dir):
""" Tests an interval transform is applied multiple times. """
seed_everything(100)
def linear_fn_counter(mod):
if isinstance(mod, torch.nn.Linear):
linear_fn_counter.count += 1
linear_fn_counter.count = 0
model = TestModule()
num_epochs = 2
qat = QuantizationAwareTraining()
qat.transforms.append(
ModelTransform(fn=linear_fn_counter, message="Counter", interval=10)
)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
# Model has 2 linear layers.
self.assertEqual(linear_fn_counter.count, 2 * (trainer.global_step // 10 + 1))
@tempdir
def test_module_quantized_during_train(self, root_dir):
""" Validate quantized aware training works as expected. """
seed_everything(100)
model = TestModule()
test_in = torch.randn(1, 32)
before_train = model.eval()(test_in)
num_epochs = 2
qat = QuantizationAwareTraining()
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
test_out = model.eval()(test_in)
self.assertGreater(
(test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
)
base_out = qat.quantized.eval()(test_in)
self.assertTrue(torch.allclose(base_out, test_out))
# Weight changed during training.
self.assertFalse(torch.allclose(before_train, test_out))
# Validate .test() call works as expected and does not change model weights.
trainer.test(model)
self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))
@tempdir
def test_quantization_without_train(self, root_dir):
""" Validate quantization occurs even without a call to .fit() first. """
seed_everything(100)
model = TestModule()
num_epochs = 2
qat = QuantizationAwareTraining()
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.test(model)
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
@tempdir
def test_quantization_and_checkpointing(self, root_dir):
""" Validate written checkpoints can be loaded back as expected. """
seed_everything(100)
model = TestModule()
num_epochs = 2
qat = QuantizationAwareTraining()
checkpoint_dir = os.path.join(root_dir, "checkpoints")
checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, save_last=True)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=checkpoint,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
# Mimic failing mid-training by not running on_fit_end.
with mock.patch.object(qat, "on_fit_end"):
trainer.fit(model)
ckpt = torch.load(os.path.join(checkpoint_dir, "last.ckpt"))
model.load_state_dict(ckpt["state_dict"])
@tempdir
def test_custom_qat(self, root_dir):
""" Tests that we can customize QAT by skipping certain layers. """
class _CustomQAT(QuantizationAwareTraining):
""" Only quantize TestModule.another_layer. """
def prepare(self, model, configs):
model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
return model
def convert(self, model, submodules):
model.another_layer = convert_fx(model.another_layer)
return model
seed_everything(100)
model = TestModule()
test_in = torch.randn(1, 32)
before_train = model.eval()(test_in)
num_epochs = 2
qat = _CustomQAT()
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
test_out = model.eval()(test_in)
self.assertGreater(
(test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
)
base_out = qat.quantized.eval()(test_in)
self.assertTrue(torch.allclose(base_out, test_out))
# Weight changed during training.
self.assertFalse(torch.allclose(before_train, test_out))
# Validate .test() call works as expected and does not change model weights.
trainer.test(model)
self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))
@tempdir
def test_submodule_qat(self, root_dir):
""" Tests that we can customize QAT through exposed API. """
seed_everything(100)
model = TestModule()
test_in = torch.randn(1, 32)
before_train = model.eval()(test_in)
num_epochs = 2
qat = QuantizationAwareTraining(
qconfig_dicts={"another_layer": {"": get_default_qat_qconfig()}}
)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[qat],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
self.assertIsNotNone(qat.prepared)
self.assertIsNotNone(qat.quantized)
test_out = model.eval()(test_in)
self.assertGreater(
(test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
)
base_out = qat.quantized.eval()(test_in)
self.assertTrue(torch.allclose(base_out, test_out))
# Weight changed during training.
self.assertFalse(torch.allclose(before_train, test_out))
# Validate .test() call works as expected and does not change model weights.
trainer.test(model)
self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))
class TestPostTrainingQuantization(unittest.TestCase):
@tempdir
def test_post_training_static_quantization(self, root_dir):
""" Validate post-training static quantization. """
seed_everything(100)
model = TestModule()
num_epochs = 4
static_quantization = PostTrainingQuantization(
qconfig_dicts={"": {"": get_default_qconfig()}}
)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[static_quantization],
max_epochs=num_epochs,
logger=False,
)
# This will both train the model + quantize it.
trainer.fit(model)
self.assertIsNotNone(static_quantization.quantized)
# Default qconfig requires calibration.
self.assertTrue(static_quantization.should_calibrate)
test_in = torch.randn(12, 32)
with mode(model, training=False) as m:
base_out = m(test_in)
with mode(static_quantization.quantized, training=False) as q:
test_out = q(test_in)
# While quantized/original won't be exact, they should be close.
self.assertLess(
((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
0.015,
"RMSE should be less than 0.015 between quantized and original.",
)
@tempdir
def test_post_training_dynamic_quantization(self, root_dir):
""" Validates post-training dynamic quantization. """
seed_everything(100)
model = TestModule()
num_epochs = 2
dynamic_quant = PostTrainingQuantization(
qconfig_dicts={"": {"": default_dynamic_qconfig}}
)
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[dynamic_quant],
max_epochs=num_epochs,
logger=False,
)
# This will both train the model + quantize it.
trainer.fit(model)
self.assertIsNotNone(dynamic_quant.quantized)
# Default qconfig requires calibration.
self.assertFalse(dynamic_quant.should_calibrate)
test_in = torch.randn(12, 32)
with mode(model, training=False) as m:
base_out = m(test_in)
with mode(dynamic_quant.quantized, training=False) as q:
test_out = q(test_in)
# While quantized/original won't be exact, they should be close.
self.assertLess(
((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
0.015,
"RMSE should be less than 0.015 between quantized and original.",
)
@tempdir
def test_custom_post_training_static_quant(self, root_dir):
""" Tests that we can customize Post-Training static by skipping certain layers. """
class _CustomStaticQuant(PostTrainingQuantization):
""" Only quantize TestModule.another_layer. """
def prepare(self, model, configs):
model.another_layer = prepare_fx(model.another_layer, configs[""])
return model
def convert(self, model, submodules):
model.another_layer = convert_fx(model.another_layer)
return model
seed_everything(100)
model = TestModule()
num_epochs = 2
static_quantization = _CustomStaticQuant()
trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False,
callbacks=[static_quantization],
max_epochs=num_epochs,
logger=False,
)
trainer.fit(model)
self.assertIsNotNone(static_quantization.quantized)
test_in = torch.randn(12, 32)
with mode(model, training=False) as m:
base_out = m(test_in)
with mode(static_quantization.quantized, training=False) as q:
test_out = q(test_in)
# While quantized/original won't be exact, they should be close.
self.assertLess(
((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
0.015,
"RMSE should be less than 0.007 between quantized and original.",
)