Commit b4cb5765 authored by raghuramank100, committed by Francisco Massa

Quantizable resnet and mobilenet models (#1471)

* add quantized models

* Modify mobilenet.py documentation and clean up comments

* Move fuse_model method to QuantizableInvertedResidual and clean up args documentation

* Restore relu settings to default in resnet.py

* Fix missing return in forward

* Fix missing return in forwards

* Change pretrained -> pretrained_float_models
  Replace InvertedResidual with block

* Update tests to follow similar structure to test_models.py, allowing for modular testing

* Replace forward method with simple function assignment

* Fix error in arguments for resnet18

* pretrained_float_model argument missing for mobilenet

* reference script for quantization aware training and post training quantization

* set pretrained_float_model as False and explicitly provide float model

* Address review comments:
  1. Replace forward with _forward
  2. Use pretrained models in reference train/eval script
  3. Modify test to skip if fbgemm is not supported

* Fix lint errors.
  Use _forward for common code between float and quantized models
  Clean up linting for reference train scripts
  Test over all quantizable models

* Update default values for args in quantization/train.py

* Update models to conform to new API with quantize argument
  Remove apex in training script, add post training quant as an option
  Add support for separate calibration data set.

* Fix minor errors in train_quantization.py

* Remove duplicate file

* Bugfix

* Minor improvements on the models

* Expose print_freq to evaluate

* Minor improvements on train_quantization.py

* Ensure that quantized models are created and run on the specified backends
  Fix errors in test only mode

* Add model urls

* Fix errors in quantized model tests.
  Speedup creation of random quantized model by removing histogram observers

* Move setting qengine prior to convert.

* Fix lint error

* Add readme.md

* Readme.md

* Fix lint
parent e79caddf

references/classification/README.md
@@ -28,3 +28,32 @@
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
    --model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004\
    --lr-step-size 1 --lr-gamma 0.98
```

## Quantized

### Parameters used for generating quantized models:

For all post training quantized models (all quantized models except mobilenet-v2), the settings are:

1. num_calibration_batches: 32
2. num_workers: 16
3. batch_size: 32
4. eval_batch_size: 128
5. backend: 'fbgemm'

For mobilenet-v2, the model was trained with quantization aware training; the settings used are:

1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
4. backend: 'qnnpack'
5. learning-rate: 0.0001
6. num_epochs: 90
7. num_observer_update_epochs: 4
8. num_batch_norm_update_epochs: 3
9. momentum: 0.9
10. lr_step_size: 30
11. lr_gamma: 0.1

Training converges at about 10 epochs.

For post training quantization, the device is set to CPU. For training, the device is set to CUDA.
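
As a sanity check, the resulting quantized models can be loaded directly through the torchvision API added in this commit. A minimal sketch (CPU-only inference; the printed shape is illustrative):

```
import torch
import torchvision

# Load a pretrained quantized model; resnet18 weights were produced with the
# 'fbgemm' backend settings listed above.
model = torchvision.models.quantization.resnet18(pretrained=True, quantize=True)
model.eval()

x = torch.rand(1, 3, 224, 224)  # dummy input batch
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])
```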

references/classification/train.py
@@ -47,12 +47,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))


-def evaluate(model, criterion, data_loader, device):
+def evaluate(model, criterion, data_loader, device, print_freq=100):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
-        for image, target in metric_logger.log_every(data_loader, 100, header):
+        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
@@ -81,35 +81,16 @@ def _get_cache_path(filepath):
    return cache_path


-def main(args):
-    if args.apex:
-        if sys.version_info < (3, 0):
-            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
-        if amp is None:
-            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-                               "to enable mixed-precision training.")
-
-    if args.output_dir:
-        utils.mkdir(args.output_dir)
-
-    utils.init_distributed_mode(args)
-    print(args)
-
-    device = torch.device(args.device)
-
-    torch.backends.cudnn.benchmark = True
-
+def load_data(traindir, valdir, cache_dataset, distributed):
    # Data loading code
    print("Loading data")
-    traindir = os.path.join(args.data_path, 'train')
-    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
-    if args.cache_dataset and os.path.exists(cache_path):
+    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
@@ -122,7 +103,7 @@ def main(args):
            transforms.ToTensor(),
            normalize,
        ]))
-        if args.cache_dataset:
+        if cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
@@ -130,7 +111,7 @@ def main(args):

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
-    if args.cache_dataset and os.path.exists(cache_path):
+    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
@@ -143,19 +124,44 @@ def main(args):
            transforms.ToTensor(),
            normalize,
        ]))
-        if args.cache_dataset:
+        if cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
-    if args.distributed:
+    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

+    return dataset, dataset_test, train_sampler, test_sampler
+
+
+def main(args):
+    if args.apex:
+        if sys.version_info < (3, 0):
+            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
+        if amp is None:
+            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
+                               "to enable mixed-precision training.")
+
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    torch.backends.cudnn.benchmark = True
+
+    train_dir = os.path.join(args.data_path, 'train')
+    val_dir = os.path.join(args.data_path, 'val')
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
+                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
...
references/classification/train_quantization.py (new file)
from __future__ import print_function
import datetime
import os
import time
import sys
import copy
import torch
import torch.utils.data
from torch import nn
import torchvision
import torch.quantization
import utils
from train import train_one_epoch, evaluate, load_data
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)
utils.init_distributed_mode(args)
print(args)
if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed "
"on distributed mode")
# Set backend engine to ensure that quantized model runs on the correct kernels
if args.backend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
torch.backends.quantized.engine = args.backend
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
# Data loading code
print("Loading data")
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
args.cache_dataset, args.distributed)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.eval_batch_size,
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
model.to(device)
if not (args.test_only or args.post_training_quantize):
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
torch.quantization.prepare_qat(model, inplace=True)
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=args.lr_step_size,
gamma=args.lr_gamma)
criterion = nn.CrossEntropyLoss()
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
ds = torch.utils.data.Subset(
dataset,
indices=list(range(args.batch_size * args.num_calibration_batches)))
data_loader_calibration = torch.utils.data.DataLoader(
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
pin_memory=True)
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig(args.backend)
torch.quantization.prepare(model, inplace=True)
# Calibrate first
print("Calibrating")
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
torch.quantization.convert(model, inplace=True)
if args.output_dir:
print('Saving quantized model')
if utils.is_main_process():
torch.save(model.state_dict(), os.path.join(args.output_dir,
'quantized_post_train_model.pth'))
print("Evaluating post-training quantized model")
evaluate(model, criterion, data_loader_test, device=device)
return
if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
return
model.apply(torch.quantization.enable_observer)
model.apply(torch.quantization.enable_fake_quant)
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
print('Starting training for epoch', epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
args.print_freq)
lr_scheduler.step()
with torch.no_grad():
if epoch >= args.num_observer_update_epochs:
print('Disabling observer for subseq epochs, epoch = ', epoch)
model.apply(torch.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
print('Freezing BN for subseq epochs, epoch = ', epoch)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
print('Evaluate QAT model')
evaluate(model, criterion, data_loader_test, device=device)
quantized_eval_model = copy.deepcopy(model)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device('cpu'))
torch.quantization.convert(quantized_eval_model, inplace=True)
print('Evaluate Quantized model')
evaluate(quantized_eval_model, criterion, data_loader_test,
device=torch.device('cpu'))
model.train()
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'eval_model': quantized_eval_model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
print('Saving models after epoch ', epoch)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path',
default='/datasets01/imagenet_full_size/061417/',
help='dataset')
parser.add_argument('--model',
default='mobilenet_v2',
help='model')
parser.add_argument('--backend',
default='qnnpack',
help='fbgemm or qnnpack')
parser.add_argument('--device',
default='cuda',
help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int,
help='batch size for calibration/training')
parser.add_argument('--eval-batch-size', default=128, type=int,
help='batch size for evaluation')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--num-observer-update-epochs',
default=4, type=int, metavar='N',
help='number of total epochs to update observers')
parser.add_argument('--num-batch-norm-update-epochs', default=3,
type=int, metavar='N',
help='number of total epochs to update batch norm stats')
parser.add_argument('--num-calibration-batches',
default=32, type=int, metavar='N',
help='number of batches of training set for \
observer calibration ')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr',
default=0.0001, type=float,
help='initial learning rate')
parser.add_argument('--momentum',
default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=30, type=int,
help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float,
help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int,
help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. \
It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--post-training-quantize",
dest="post_training_quantize",
help="Post training quantize the model",
action="store_true",
)
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url',
default='env://',
help='url used to set up distributed training')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)
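
The post-training path in the script above condenses to a few eager-mode calls. A minimal sketch, assuming a calibration `data_loader` is already built:

```
import torch
import torchvision

model = torchvision.models.quantization.resnet50(pretrained=True, quantize=False)
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# calibration pass over a few batches (data_loader is assumed)
with torch.no_grad():
    for image, _ in data_loader:
        model(image)

torch.quantization.convert(model, inplace=True)
```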

test/test_quantized_models.py (new file)
import torchvision
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_available_quantizable_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# list of models that are not scriptable
scriptable_quantizable_models_blacklist = []
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and
                     'qnnpack' in torch.backends.quantized.supported_engines,
                     "This PyTorch build was not built with fbgemm and qnnpack")
class ModelTester(TestCase):
def check_quantized_model(self, model, input_shape):
x = torch.rand(input_shape)
model(x)
return
def check_script(self, model, name):
if name in scriptable_quantizable_models_blacklist:
return
scriptable = True
msg = ""
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertTrue(scriptable, msg)
def _test_classification_model(self, name, input_shape):
        # passing a num_classes value other than 1000 would make the test stricter
# First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[name](pretrained=True, quantize=True)
self.check_quantized_model(model, input_shape)
for eval in [True, False]:
model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False)
if eval:
model.eval()
model.qconfig = torch.quantization.default_qconfig
else:
model.train()
model.qconfig = torch.quantization.default_qat_qconfig
model.fuse_model()
if eval:
torch.quantization.prepare(model, inplace=True)
else:
torch.quantization.prepare_qat(model, inplace=True)
model.eval()
torch.quantization.convert(model, inplace=True)
self.check_script(model, name)
for model_name in get_available_quantizable_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape)
setattr(ModelTester, "test_" + model_name, do_test)
if __name__ == '__main__':
unittest.main()
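
The `model_name=model_name` default argument in `do_test` above is load-bearing: Python closures capture variables rather than values, so without it every generated test would see the final model name from the loop. A toy illustration of the same idiom:

```
# Without the default argument, both lambdas would return "b".
fns = [lambda name=name: name for name in ["a", "b"]]
assert [f() for f in fns] == ["a", "b"]
```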

torchvision/models/__init__.py
@@ -11,3 +11,4 @@ from .shufflenetv2 import *
from . import segmentation
from . import detection
from . import video
+from . import quantization

torchvision/models/mobilenet.py
@@ -70,7 +70,12 @@ class InvertedResidual(nn.Module):


class MobileNetV2(nn.Module):
-    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
+    def __init__(self,
+                 num_classes=1000,
+                 width_mult=1.0,
+                 inverted_residual_setting=None,
+                 round_nearest=8,
+                 block=None):
        """
        MobileNet V2 main class
@@ -80,9 +85,13 @@ class MobileNetV2(nn.Module):
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
+            block: Module specifying inverted residual building block for mobilenet
        """
        super(MobileNetV2, self).__init__()
-        block = InvertedResidual
+
+        if block is None:
+            block = InvertedResidual
        input_channel = 32
        last_channel = 1280
@@ -138,12 +147,15 @@ class MobileNetV2(nn.Module):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

-    def forward(self, x):
+    def _forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x

+    # Allow for accessing forward method in an inherited class
+    forward = _forward


def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
...
torchvision/models/quantization/__init__.py (new file)
from .mobilenet import *
from .resnet import *

torchvision/models/quantization/mobilenet.py (new file)
from torch import nn
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenet import InvertedResidual, ConvBNReLU, MobileNetV2, model_urls
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
quant_model_urls = {
'mobilenet_v2_qnnpack':
'https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth'
}
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, *args, **kwargs):
super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
if self.use_res_connect:
return self.skip_add.add(x, self.conv(x))
else:
return self.conv(x)
def fuse_model(self):
for idx in range(len(self.conv)):
if type(self.conv[idx]) == nn.Conv2d:
fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
class QuantizableMobileNetV2(MobileNetV2):
def __init__(self, *args, **kwargs):
"""
MobileNet V2 main class
Args:
Inherits args from floating point MobileNetV2
"""
super(QuantizableMobileNetV2, self).__init__(*args, **kwargs)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self._forward(x)
x = self.dequant(x)
return x
def fuse_model(self):
for m in self.modules():
if type(m) == ConvBNReLU:
fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == QuantizableInvertedResidual:
m.fuse_model()
def mobilenet_v2(pretrained=False, progress=True, quantize=False, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
<https://arxiv.org/abs/1801.04381>`_.
Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet.
progress (bool): If True, displays a progress bar of the download to stderr
quantize(bool): If True, returns a quantized model, else returns a float model
"""
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'qnnpack'
quantize_model(model, backend)
else:
assert pretrained in [True, False]
if pretrained:
if quantize:
model_url = quant_model_urls['mobilenet_v2_' + backend]
else:
model_url = model_urls['mobilenet_v2']
state_dict = load_state_dict_from_url(model_url,
progress=progress)
model.load_state_dict(state_dict)
return model
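
Loading the pretrained quantized MobileNet V2 requires the qnnpack engine, which not every PyTorch build ships. A hedged sketch with an explicit engine check:

```
import torch
import torchvision

# qnnpack may be unavailable (e.g. some server builds), so guard on it.
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
    model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
    model.eval()
    out = model(torch.rand(1, 3, 224, 224))
```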

torchvision/models/quantization/resnet.py (new file)
import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from torch._jit_internal import Optional
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableResNet', 'resnet18', 'resnet50',
'resnext101_32x8d']
quant_model_urls = {
'resnet18_fbgemm':
'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth',
'resnet50_fbgemm':
'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth',
'resnext101_32x8d_fbgemm':
'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth',
}
class QuantizableBasicBlock(BasicBlock):
def __init__(self, *args, **kwargs):
super(QuantizableBasicBlock, self).__init__(*args, **kwargs)
self.add_relu = torch.nn.quantized.FloatFunctional()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.add_relu.add_relu(out, identity)
return out
def fuse_model(self):
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
['conv2', 'bn2']], inplace=True)
if self.downsample:
torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
class QuantizableBottleneck(Bottleneck):
def __init__(self, *args, **kwargs):
super(QuantizableBottleneck, self).__init__(*args, **kwargs)
self.skip_add_relu = nn.quantized.FloatFunctional()
self.relu1 = nn.ReLU(inplace=False)
self.relu2 = nn.ReLU(inplace=False)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu2(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.skip_add_relu.add_relu(out, identity)
return out
def fuse_model(self):
fuse_modules(self, [['conv1', 'bn1', 'relu1'],
['conv2', 'bn2', 'relu2'],
['conv3', 'bn3']], inplace=True)
if self.downsample:
torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
class QuantizableResNet(ResNet):
def __init__(self, *args, **kwargs):
super(QuantizableResNet, self).__init__(*args, **kwargs)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
# Ensure scriptability
# super(QuantizableResNet,self).forward(x)
# is not scriptable
x = self._forward(x)
x = self.dequant(x)
return x
def fuse_model(self):
r"""Fuse conv/bn/relu modules in resnet models
Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
Model is modified in place. Note that this operation does not change numerics
and the model after modification is in floating point
"""
fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
for m in self.modules():
if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock:
m.fuse_model()
def _resnet(arch, block, layers, pretrained, progress, quantize, **kwargs):
model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'fbgemm'
quantize_model(model, backend)
else:
assert pretrained in [True, False]
if pretrained:
if quantize:
model_url = quant_model_urls[arch + '_' + backend]
else:
model_url = model_urls[arch]
state_dict = load_state_dict_from_url(model_url,
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, quantize=False, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress,
quantize, **kwargs)
def resnet50(pretrained=False, progress=True, quantize=False, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress,
quantize, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, quantize=False, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', QuantizableBottleneck, [3, 4, 23, 3],
pretrained, progress, quantize, **kwargs)

torchvision/models/quantization/utils.py (new file)
import torch
from torch import nn
def _replace_relu(module):
reassign = {}
for name, mod in module.named_children():
_replace_relu(mod)
# Checking for explicit type instead of instance
# as we only want to replace modules of the exact type
# not inherited classes
if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
reassign[name] = nn.ReLU(inplace=False)
for key, value in reassign.items():
module._modules[key] = value
def quantize_model(model, backend):
_dummy_input_data = torch.rand(1, 3, 299, 299)
if backend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported ")
torch.backends.quantized.engine = backend
model.eval()
# Make sure that weight qconfig matches that of the serialized models
if backend == 'fbgemm':
model.qconfig = torch.quantization.QConfig(
activation=torch.quantization.default_observer,
weight=torch.quantization.default_per_channel_weight_observer)
elif backend == 'qnnpack':
model.qconfig = torch.quantization.QConfig(
activation=torch.quantization.default_observer,
weight=torch.quantization.default_weight_observer)
model.fuse_model()
torch.quantization.prepare(model, inplace=True)
model(_dummy_input_data)
torch.quantization.convert(model, inplace=True)
return
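
`quantize_model` runs the whole post-training pipeline in place: fuse, attach a backend-matched qconfig, calibrate on a dummy input, and convert. A small usage sketch (the printed type is illustrative):

```
import torch
import torchvision
from torchvision.models.quantization.utils import quantize_model

model = torchvision.models.quantization.resnet18(pretrained=False, quantize=False)
if 'fbgemm' in torch.backends.quantized.supported_engines:
    quantize_model(model, 'fbgemm')  # fuses, calibrates on random data, converts in place
    print(type(model.fc))            # e.g. a quantized Linear after conversion
```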

torchvision/models/resnet.py
@@ -194,7 +194,7 @@ class ResNet(nn.Module):
        return nn.Sequential(*layers)

-    def forward(self, x):
+    def _forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
@@ -211,6 +211,9 @@ class ResNet(nn.Module):

        return x

+    # Allow for accessing forward method in an inherited class
+    forward = _forward


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
...