Unverified Commit 370e88df authored by chenbohua3, committed by GitHub

Add post training observer_quantizer (#3915)

parent 3f1e4f55
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, device, train_loader, optimizer):
    model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def calibration(model, device, test_loader):
    model.eval()
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            model(data)
def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = NaiveModel()
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1'],
    }, {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['max_pool2'],
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # Train the model to get a baseline performance.
    for epoch in range(5):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
    test(model, device, test_loader)

    # Construct the ObserverQuantizer. Note that currently ObserverQuantizer only works
    # in evaluation mode.
    quantizer = ObserverQuantizer(model.eval(), configure_list, optimizer)
    # Use the test data set for calibration; this does not change the model parameters.
    calibration(model, device, test_loader)
    # Obtain the quantization information and switch the model to "accuracy verification" mode.
    quantizer.compress()
    # Measure the accuracy of the quantized model.
    test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)
    # For now the quantization settings of ObserverQuantizer do not match TensorRT,
    # so TensorRT conversion is not supported.
    # Current settings:
    #   weight     : per_tensor_symmetric, qint8
    #   activation : per_tensor_affine, quint8, reduce_range=True


if __name__ == '__main__':
    main()
from torch.quantization import default_weight_observer, default_histogram_observer
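# These are thin re-exports of PyTorch's built-in observers. In the PyTorch versions
# targeted here, `default_weight_observer` is a per-tensor symmetric qint8 MinMaxObserver
# and `default_histogram_observer` is a HistogramObserver with reduce_range=True, which
# matches the hard-coded settings documented in ObserverQuantizer.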
__all__ = ["default_weight_observer", "default_histogram_observer"]
@@ -3,12 +3,15 @@
import logging
import copy
from collections import defaultdict

import torch
from schema import Schema, And, Or, Optional

from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
from .observers import default_weight_observer, default_histogram_observer

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']

logger = logging.getLogger(__name__)
@@ -120,6 +123,231 @@ class QATGrad(QuantGrad):
        return grad_output
class ObserverQuantizer(Quantizer):
    """This quantizer uses observers to record weight/activation statistics and derive quantization information.

    The whole process can be divided into three steps:
        1. Observers are registered at the places where quantization will happen (just like registering hooks).
        2. The observers record tensors' statistics during calibration.
        3. Scale and zero point are obtained after calibration.

    Note that the observer type, tensor dtype and quantization qscheme are hard-coded for now. Their customization
    is under development and will be ready soon.
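
    Examples
    --------
    A minimal usage sketch, mirroring the MNIST example added in this commit
    (``model``, ``config_list`` and ``calib_loader`` are assumed to be provided by the caller)::

        quantizer = ObserverQuantizer(model.eval(), config_list)
        # calibration: run representative data through the model so the observers
        # can record tensor statistics
        for data, _ in calib_loader:
            model(data)
        # compute scale / zero point and switch to simulated-quantization mode
        quantizer.compress()
        calibration_config = quantizer.export_model('model.pth', 'calibration.pth')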
"""
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
        # are hard-coded.
        # TODO:
        # 1. Support dtype and qscheme customization through config_list. Current settings:
        #        weight observer     : per_tensor_symmetric, qint8
        #        activation observer : per_tensor_affine, quint8, reduce_range=True
        # 2. Add more kinds of observers, such as Kullback-Leibler divergence.
        # 3. Add batch normalization folding.
        assert not model.training, "Currently the observer quantizer only works in evaluation mode."
        self.quant_grad = QuantForward()
        self.device = next(model.parameters()).device
        modules_to_compress = self.get_modules_to_compress()
        all_observers = defaultdict(dict)
        weight_q_min, weight_q_max = -127, 127
        activation_q_min, activation_q_max = 0, 127  # reduce_range is set to True
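        # The ranges above are fixed to match the hard-coded observers: weights use
        # symmetric signed int8 restricted to [-127, 127], and activations use unsigned
        # int8 with reduce_range=True, which halves the usable range to [0, 127].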
        self.compressed = False

        for layer, config in modules_to_compress:
            layer_name = layer.name
            module = layer.module
            if "weight" in config.get("quant_types", []):
                all_observers[layer_name]["weight"] = default_weight_observer()
                setattr(module, "weight_qmax", weight_q_max)
                setattr(module, "weight_qmin", weight_q_min)
            if "input" in config.get("quant_types", []):
                all_observers[layer_name]["input"] = default_histogram_observer()
                setattr(module, "input_qmax", activation_q_max)
                setattr(module, "input_qmin", activation_q_min)
            if "output" in config.get("quant_types", []):
                all_observers[layer_name]["output"] = default_histogram_observer()
                setattr(module, "output_qmax", activation_q_max)
                setattr(module, "output_qmin", activation_q_min)
        self.all_observers = all_observers
        self.bound_model.to(self.device)
    def validate_config(self, model, config_list):
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({
                Optional('weight'): And(int, lambda n: n == 8),
                Optional('output'): And(int, lambda n: n == 8),
                Optional('input'): And(int, lambda n: n == 8),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)
    def record(self, wrapper, quant_type, tensor):
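        # Calibration path: feed the given tensor(s) to the matching observer so it can
        # update its running statistics (min/max or histogram). Observation is done on CPU.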
        name = wrapper.name
        observer = self.all_observers[name][quant_type]
        if isinstance(tensor, tuple):
            # NB: This only works for single tensor
            tensor = (t.cpu() for t in tensor)
            observer(*tensor)
        else:
            observer(tensor.cpu())
    def calculate_qparams(self, name, quant_type):
        observer = self.all_observers[name][quant_type]
        scale, zero_point = observer.calculate_qparams()
        return scale, zero_point
    def _quantize(self, x, scale, zero_point, qmin, qmax):
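        # Simulated ("fake") quantization: project x onto the integer grid defined by
        # scale and zero_point, clamp to [qmin, qmax], round, then map back to float,
        # so the returned tensor carries the quantization error while staying in float.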
        x = x / scale + zero_point
        x = torch.clamp(x, qmin, qmax)
        x = torch.round(x)
        x = (x - zero_point) * scale
        return x
    def quantize_input(self, *inputs, wrapper, **kwargs):
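        # Before compress() this hook only records input statistics; after compress()
        # it replaces the first input with its simulated-quantized version. The same
        # two-phase pattern is used by quantize_weight and quantize_output below.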
        if self.compressed:
            module = wrapper.module
            new_input = self._quantize(inputs[0],
                                       module.input_scale,
                                       module.input_zero_point,
                                       module.input_qmin,
                                       module.input_qmax)
            list_inp = list(inputs)
            list_inp[0] = new_input
            inputs = tuple(list_inp)
        else:
            self.record(wrapper, 'input', inputs)
        return inputs
    def quantize_weight(self, wrapper, **kwargs):
        # If ObserverQuantizer.compress has been executed, the weight has already been
        # replaced by its pseudo-quantized version, so there is no need to quantize it again.
        if self.compressed:
            return
        module = wrapper.module
        old_weight = module.weight
        self.record(wrapper, 'weight', old_weight)
    def quantize_output(self, output, wrapper, **kwargs):
        if self.compressed:
            module = wrapper.module
            new_output = self._quantize(output,
                                        module.output_scale,
                                        module.output_zero_point,
                                        module.output_qmin,
                                        module.output_qmax)
        else:
            self.record(wrapper, 'output', output)
            new_output = output
        return new_output
    def compress(self):
        """
        Calculate quantization information of each tensor. Note that the inference of
        the compressed model will no longer update the corresponding observers. Instead,
        the quantization process will be simulated, which is used to test the accuracy
        of the quantization.
        """
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            module = layer.module
            if "weight" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'weight')
                module.register_buffer('weight_scale', scale.to(self.device))
                module.register_buffer('weight_zero_point', zero_point.to(self.device))
                weight = module.weight
                quantized_weight = self._quantize(weight,
                                                  module.weight_scale,
                                                  module.weight_zero_point,
                                                  module.weight_qmin,
                                                  module.weight_qmax)
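                # Replace the original weight parameter with its simulated-quantized
                # version so that subsequent inference and export use the quantized weights.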
                delattr(module, 'weight')
                module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
            if "input" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'input')
                module.register_buffer('input_scale', scale.to(self.device))
                module.register_buffer('input_zero_point', zero_point.to(self.device))
            if "output" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'output')
                module.register_buffer('output_scale', scale.to(self.device))
                module.register_buffer('output_zero_point', zero_point.to(self.device))
        self.compressed = True
        super().compress()
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional).

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save the quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape of the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_scale'):
                calibration_config[name]['weight_bit'] = 8
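                # For per-tensor symmetric weights the tracked range can be recovered
                # from the scale alone: max = scale * qmax and min = -max.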
                val = float(module.weight_scale * module.weight_qmax)
                calibration_config[name]['tracked_max_weight'] = val
                calibration_config[name]['tracked_min_weight'] = -val
                calibration_config[name]['tracked_weight_qmin'] = -127
                calibration_config[name]['tracked_weight_qmax'] = 127
                # refactor these magic numbers when customizations of dtype and qscheme are ready.
            if hasattr(module, 'input_scale'):
                calibration_config[name]['input_bit'] = 8
                max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                calibration_config[name]['tracked_min_input'] = min_input
                calibration_config[name]['tracked_max_input'] = max_input
                calibration_config[name]['tracked_input_qmin'] = 0
                calibration_config[name]['tracked_input_qmax'] = 127
            if hasattr(module, 'output_scale'):
                calibration_config[name]['activation_bit'] = 8
                max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point))
                min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point))
                calibration_config[name]['tracked_min_activation'] = min_input
                calibration_config[name]['tracked_max_activation'] = max_input
                calibration_config[name]['tracked_activation_qmin'] = 0
                calibration_config[name]['tracked_activation_qmax'] = 127
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config
    def _del_simulated_attr(self, module):
        """
        Delete redundant parameters in the quantized module.
        """
        del_attr_list = ['old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin',
                         'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale',
                         'input_zero_point', 'output_scale', 'output_zero_point']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
class QAT_Quantizer(Quantizer):
    """Quantizer defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from unittest import TestCase, main
import numpy as np
import torch
@@ -263,6 +264,46 @@ class CompressorTestCase(TestCase):
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))
    def test_torch_observer_quantizer(self):
        model = TorchModel()

        # test invalid config
        # only support 8bit for now
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 5,
            'op_types': ['Conv2d', 'Linear']
        }]
        with self.assertRaises(schema.SchemaError):
            torch_quantizer.ObserverQuantizer(model, config_list)

        # weight will not change for now
        model = TorchModel().eval()
        origin_parameters = copy.deepcopy(dict(model.named_parameters()))
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }]
        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
        input = torch.randn(1, 1, 28, 28)
        model(input)
        quantizer.compress()
        buffers = dict(model.named_buffers())
        scales = {k: v for k, v in buffers.items() if 'scale' in k}

        model_path = "test_model.pth"
        calibration_path = "test_calibration.pth"
        calibration_config = quantizer.export_model(model_path, calibration_path)

        new_parameters = dict(model.named_parameters())
        for layer_name, v in calibration_config.items():
            scale_name = layer_name + '.module.weight_scale'
            weight_name = layer_name + '.weight'
            s = float(scales[scale_name])
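            # Simulated quantization with round-to-nearest moves each weight by at most
            # half a quantization step, hence the 0.5 * scale tolerance below.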
            self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s))

        self.assertTrue(calibration_config is not None)
        self.assertTrue(len(calibration_config) == 4)
    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
......