Commit 5d8068d8 authored by Kai Zhang, committed by Facebook GitHub Bot

Copy quantization callback to D2go

Summary: As titled. Make a copy of quantization callback to unblock D2go OSS.

Reviewed By: zhanghang1989

Differential Revision: D26735525

fbshipit-source-id: 12b77f04cfa1361e856b26ea218a262da1fadd88
parent f23248c0
@@ -4,10 +4,12 @@
import logging
import os
import warnings
from contextlib import contextmanager
from typing import Dict, Iterator

import detectron2.utils.comm as comm
import torch
from d2go.config import CfgNode
from fvcore.common.file_io import PathManager
from tabulate import tabulate
@@ -16,6 +18,7 @@ from .tensorboard_log_util import get_tensorboard_log_dir

logger = logging.getLogger(__name__)


def check_version(library, min_version, warning_only=False):
    """Check the version of the library satisfies the provided minimum version.
    An exception is thrown if the check does not pass.
@@ -27,11 +30,14 @@ def check_version(library, min_version, warning_only=False):
            Printing a warning instead of throwing an exception.
    """
    from distutils.version import LooseVersion

    version = library.__version__
    bad_version = LooseVersion(version) < LooseVersion(min_version)
    if bad_version:
        msg = (
            f"Installed {library.__name__} version {version} does not satisfy the "
            f"minimum required version {min_version}"
        )
        if warning_only:
            warnings.warn(msg)
        else:
@@ -39,6 +45,7 @@ def check_version(library, min_version, warning_only=False):
        return False
    return True
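
# Usage sketch (illustrative): `check_version` returns a bool and, with
# warning_only=True, warns instead of raising when the version is unsatisfied.
#   import torch
#   check_version(torch, "1.7.0", warning_only=True)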


def metrics_dict_to_metrics_table(dic):
    assert isinstance(dic, dict)
    ret = []
@@ -62,7 +69,9 @@ def print_metrics_table(metrics_dic):
    logger.info("Metrics table: \n" + metrics_tabulate)


def dump_trained_model_configs(
    output_dir: str, trained_cfgs: Dict[str, CfgNode]
) -> Dict[str, str]:
    """Writes trained model config files to output_dir.

    Args:
@@ -82,3 +91,14 @@ def dump_trained_model_configs(
        with PathManager.open(config_file, "w") as f:
            f.write(trained_cfg.dump())
    return trained_model_configs
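
# Usage sketch (illustrative; `cfg` stands in for a trained model's CfgNode):
#   paths = dump_trained_model_configs(output_dir="./out", trained_cfgs={"model_final": cfg})
#   # `paths` maps each model name to the path of its config file under ./out.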


@contextmanager
def mode(net: torch.nn.Module, training: bool) -> Iterator[torch.nn.Module]:
    """Temporarily switch to training/evaluation mode."""
    istrain = net.training
    try:
        net.train(training)
        yield net
    finally:
        net.train(istrain)
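
# Usage sketch: run inference in eval mode without clobbering the module's
# current training flag (`net` and `inputs` are placeholders):
#   with mode(net, training=False) as m:
#       out = m(inputs)
#   # net.training is restored afterwards, even if the forward pass raises.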
@@ -4,6 +4,7 @@
import os
from functools import wraps
from tempfile import TemporaryDirectory

import torch
import torch.distributed as dist
@@ -40,3 +41,13 @@ def enable_ddp_env(func):
        return ret

    return wrapper


def tempdir(func):
    """A decorator for creating a temporary directory that is cleaned up after function execution."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with TemporaryDirectory() as temp:
            return func(self, temp, *args, **kwargs)

    return wrapper
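
# Usage sketch: the wrapped test method receives the temporary directory path
# right after `self`, and the directory is removed once the method returns:
#   class MyTest(unittest.TestCase):  # hypothetical test case
#       @tempdir
#       def test_writes_files(self, tmp_dir):
#           ...  # anything created under tmp_dir is cleaned up automatically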
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

from functools import wraps
from tempfile import TemporaryDirectory
from typing import Optional

import torch
from pytorch_lightning import LightningModule
from torch.utils.data.dataset import Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class TestModule(LightningModule):
    def __init__(self, epoch_min_loss_override: Optional[int] = None):
        """LightningModule for testing purposes.

        Args:
            epoch_min_loss_override (int, optional): Pass in an epoch that will be set
                to the minimum validation loss for testing purposes (zero based).
                If None this is ignored. Defaults to None.
        """
        super().__init__()
        self.layer = torch.nn.Linear(in_features=32, out_features=2)
        self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
        self.epoch_min_loss_override = epoch_min_loss_override

    def forward(self, x):
        x = self.layer(x)
        return self.another_layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights
        # during `Trainer.fit` calls.
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}

    def validation_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss, "checkpoint_on": loss}

    def test_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"output": output, "loss": loss}

    def training_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("avg_loss", avg_loss)

    def validation_epoch_end(self, outputs) -> None:
        avg_val_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        # For testing purposes allow a nominated epoch to have a low loss.
        if self.current_epoch == self.epoch_min_loss_override:
            avg_val_loss -= 1e10
        self.log("val_loss", avg_val_loss)
        self.log("checkpoint_on", avg_val_loss)

    def test_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack(
            [torch.randn(1, requires_grad=True) for _ in outputs]
        ).mean()
        self.log("val_loss", avg_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))
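
# Usage sketch: the module is self-contained (its dataloaders wrap
# RandomDataset), so a Lightning Trainer can drive it directly:
#   from pytorch_lightning import Trainer
#   trainer = Trainer(max_epochs=1, logger=False, checkpoint_callback=False)
#   trainer.fit(TestModule())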
#!/usr/bin/env python3
# pyre-unsafe

import os
import unittest

import mock
import torch
from d2go.runner.callbacks.quantization import (
    PostTrainingQuantization,
    QuantizationAwareTraining,
    ModelTransform,
    get_default_qat_qconfig,
)
from d2go.tests.helper import tempdir
from d2go.tests.lightning_test_module import TestModule
from d2go.utils.misc import mode
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from torch.quantization import (  # @manual
    default_dynamic_qconfig,
    get_default_qconfig,
)


class TestModelTransform(unittest.TestCase):
    """Tests ModelTransforms."""

    def test_invalid_construction_type_error(self):
        """Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]."""
        with self.assertRaises(TypeError):
            _ = ModelTransform()
        with self.assertRaises(TypeError):
            _ = ModelTransform(fn=lambda x: x)
        with self.assertRaises(TypeError):
            _ = ModelTransform(message="No function defined")
        with self.assertRaises(TypeError):
            _ = ModelTransform(
                fn=lambda x: x,
                message="Specified both step and interval",
                step=1,
                interval=1,
            )

    def test_positivity_value_error(self):
        """Validates ModelTransforms are constructed with only valid arguments."""

        def identity(x):
            return x

        with self.assertRaises(ValueError):
            _ = ModelTransform(fn=identity, message="Negative step", step=-1)
        with self.assertRaises(ValueError):
            _ = ModelTransform(fn=identity, message="Zero interval", interval=0)
        with self.assertRaises(ValueError):
            _ = ModelTransform(fn=identity, message="Negative interval", interval=-1)
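
# For contrast with the failure cases above, a well-formed ModelTransform
# supplies fn, message, and exactly one of step/interval (values illustrative):
#   _ = ModelTransform(
#       fn=torch.quantization.enable_observer,
#       message="Enable observers",
#       step=300,
#   )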


class TestQuantizationAwareTraining(unittest.TestCase):
    def test_qat_misconfiguration(self):
        """Tests failure when misconfiguring the QAT Callback."""
        invalid_params = [
            {"start_step": -1},
            {"enable_observer": (42, 42)},
            {"enable_observer": (42, 21)},
            {"enable_observer": (-1, None)},
            {"freeze_bn_step": -1},
        ]
        for invalid_param in invalid_params:
            with self.assertRaises(ValueError):
                _ = QuantizationAwareTraining(**invalid_param)

    def test_qat_transforms(self):
        """Tests the appropriate ModelTransforms are defined with QAT."""
        qat = QuantizationAwareTraining(
            start_step=300, enable_observer=(350, 500), freeze_bn_step=550
        )
        trainer = Trainer()
        module = TestModule()
        qat.setup(trainer, module, stage="train")

        self.assertGreater(len(qat.transforms), 0)

        def assertContainsTransformsAtStep(step):
            """
            Asserts at least one transform exists at the specified step and
            that it is removed after the step begins.
            """
            self.assertGreater(
                len(
                    [
                        transform
                        for transform in qat.transforms
                        if transform.step == step
                    ]
                ),
                0,
                f"step={step}",
            )
            trainer.global_step = step
            qat.on_train_batch_start(
                trainer, module, batch=None, batch_idx=0, dataloader_idx=0
            )
            self.assertEqual(
                len(
                    [
                        transform
                        for transform in qat.transforms
                        if transform.step == step
                    ]
                ),
                0,
                f"step={step}",
            )

        assertContainsTransformsAtStep(step=300)
        assertContainsTransformsAtStep(step=350)
        assertContainsTransformsAtStep(step=500)
        assertContainsTransformsAtStep(step=550)

    @tempdir
    def test_qat_interval_transform(self, root_dir):
        """Tests an interval transform is applied multiple times."""
        seed_everything(100)

        def linear_fn_counter(mod):
            if isinstance(mod, torch.nn.Linear):
                linear_fn_counter.count += 1

        linear_fn_counter.count = 0

        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        qat.transforms.append(
            ModelTransform(fn=linear_fn_counter, message="Counter", interval=10)
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        # Model has 2 linear layers.
        self.assertEqual(linear_fn_counter.count, 2 * (trainer.global_step // 10 + 1))

    @tempdir
    def test_module_quantized_during_train(self, root_dir):
        """Validate quantization-aware training works as expected."""
        seed_everything(100)

        model = TestModule()
        test_in = torch.randn(1, 32)
        before_train = model.eval()(test_in)

        num_epochs = 2
        qat = QuantizationAwareTraining()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)
        test_out = model.eval()(test_in)
        self.assertGreater(
            (test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
        )
        base_out = qat.quantized.eval()(test_in)
        self.assertTrue(torch.allclose(base_out, test_out))
        # Weight changed during training.
        self.assertFalse(torch.allclose(before_train, test_out))

        # Validate .test() call works as expected and does not change model weights.
        trainer.test(model)
        self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))

    @tempdir
    def test_quantization_without_train(self, root_dir):
        """Validate quantization occurs even without a call to .fit() first."""
        seed_everything(100)
        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.test(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)

    @tempdir
    def test_quantization_and_checkpointing(self, root_dir):
        """Validate written checkpoints can be loaded back as expected."""
        seed_everything(100)
        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        checkpoint_dir = os.path.join(root_dir, "checkpoints")
        checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, save_last=True)
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=checkpoint,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        # Mimic failing mid-training by not running on_fit_end.
        with mock.patch.object(qat, "on_fit_end"):
            trainer.fit(model)

        ckpt = torch.load(os.path.join(checkpoint_dir, "last.ckpt"))
        model.load_state_dict(ckpt["state_dict"])

    @tempdir
    def test_custom_qat(self, root_dir):
        """Tests that we can customize QAT by skipping certain layers."""

        class _CustomQAT(QuantizationAwareTraining):
            """Only quantize TestModule.another_layer."""

            def prepare(self, model, configs):
                model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
                return model

            def convert(self, model, submodules):
                model.another_layer = convert_fx(model.another_layer)
                return model

        seed_everything(100)
        model = TestModule()
        test_in = torch.randn(1, 32)
        before_train = model.eval()(test_in)

        num_epochs = 2
        qat = _CustomQAT()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)
        test_out = model.eval()(test_in)
        self.assertGreater(
            (test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
        )
        base_out = qat.quantized.eval()(test_in)
        self.assertTrue(torch.allclose(base_out, test_out))
        # Weight changed during training.
        self.assertFalse(torch.allclose(before_train, test_out))

        # Validate .test() call works as expected and does not change model weights.
        trainer.test(model)
        self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))

    @tempdir
    def test_submodule_qat(self, root_dir):
        """Tests that we can customize QAT through the exposed API."""
        seed_everything(100)
        model = TestModule()
        test_in = torch.randn(1, 32)
        before_train = model.eval()(test_in)

        num_epochs = 2
        qat = QuantizationAwareTraining(
            qconfig_dicts={"another_layer": {"": get_default_qat_qconfig()}}
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)
        test_out = model.eval()(test_in)
        self.assertGreater(
            (test_out ** 2).sum(), 0.03, "With the given seed, L2^2 should be > 0.03."
        )
        base_out = qat.quantized.eval()(test_in)
        self.assertTrue(torch.allclose(base_out, test_out))
        # Weight changed during training.
        self.assertFalse(torch.allclose(before_train, test_out))

        # Validate .test() call works as expected and does not change model weights.
        trainer.test(model)
        self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))


class TestPostTrainingQuantization(unittest.TestCase):
    @tempdir
    def test_post_training_static_quantization(self, root_dir):
        """Validate post-training static quantization."""
        seed_everything(100)
        model = TestModule()
        num_epochs = 4
        static_quantization = PostTrainingQuantization(
            qconfig_dicts={"": {"": get_default_qconfig()}}
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[static_quantization],
            max_epochs=num_epochs,
            logger=False,
        )
        # This will both train the model + quantize it.
        trainer.fit(model)

        self.assertIsNotNone(static_quantization.quantized)
        # Default qconfig requires calibration.
        self.assertTrue(static_quantization.should_calibrate)

        test_in = torch.randn(12, 32)
        with mode(model, training=False) as m:
            base_out = m(test_in)
        with mode(static_quantization.quantized, training=False) as q:
            test_out = q(test_in)

        # While quantized/original won't be exact, they should be close.
        self.assertLess(
            ((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
            0.015,
            "RMSE should be less than 0.015 between quantized and original.",
        )

    @tempdir
    def test_post_training_dynamic_quantization(self, root_dir):
        """Validates post-training dynamic quantization."""
        seed_everything(100)
        model = TestModule()
        num_epochs = 2
        dynamic_quant = PostTrainingQuantization(
            qconfig_dicts={"": {"": default_dynamic_qconfig}}
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[dynamic_quant],
            max_epochs=num_epochs,
            logger=False,
        )
        # This will both train the model + quantize it.
        trainer.fit(model)

        self.assertIsNotNone(dynamic_quant.quantized)
        # Dynamic quantization does not require calibration.
        self.assertFalse(dynamic_quant.should_calibrate)

        test_in = torch.randn(12, 32)
        with mode(model, training=False) as m:
            base_out = m(test_in)
        with mode(dynamic_quant.quantized, training=False) as q:
            test_out = q(test_in)

        # While quantized/original won't be exact, they should be close.
        self.assertLess(
            ((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
            0.015,
            "RMSE should be less than 0.015 between quantized and original.",
        )

    @tempdir
    def test_custom_post_training_static_quant(self, root_dir):
        """Tests that we can customize post-training static quantization by skipping certain layers."""

        class _CustomStaticQuant(PostTrainingQuantization):
            """Only quantize TestModule.another_layer."""

            def prepare(self, model, configs):
                model.another_layer = prepare_fx(model.another_layer, configs[""])
                return model

            def convert(self, model, submodules):
                model.another_layer = convert_fx(model.another_layer)
                return model

        seed_everything(100)
        model = TestModule()
        num_epochs = 2
        static_quantization = _CustomStaticQuant()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[static_quantization],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(static_quantization.quantized)

        test_in = torch.randn(12, 32)
        with mode(model, training=False) as m:
            base_out = m(test_in)
        with mode(static_quantization.quantized, training=False) as q:
            test_out = q(test_in)

        # While quantized/original won't be exact, they should be close.
        self.assertLess(
            ((((test_out - base_out) ** 2).sum(axis=1)) ** (1 / 2)).mean(),
            0.015,
            "RMSE should be less than 0.015 between quantized and original.",
        )