Unverified Commit c9cd53aa authored by chenbohua3, committed by GitHub

support dtype&scheme customization for QAT quantizer (#4137)

parent b0f34da1
@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
you can define different behavior for different types.
Returns
-------
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
"""
# for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
if quant_type == QuantType.QUANT_OUTPUT:
if quant_type == QuantType.OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
@@ -2,11 +2,13 @@ import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
@@ -58,22 +60,32 @@ def main():
# {'quant_types': ['input'], 'op_names': ['b']} in the configure_list.
configure_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1', 'conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8, },
'op_names': ['relu1', 'relu2']
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1'],
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc2'],
}]
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1', 'conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8, },
'op_names': ['relu1', 'relu2']
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1', 'fc2'],
}]
# you can also set the quantization dtype and scheme layer-wise through configure_list like:
# configure_list = [{
# 'quant_types': ['weight', 'input'],
# 'quant_bits': {'weight': 8, 'input': 8},
# 'op_names': ['conv1', 'conv2'],
# 'quant_dtype': 'int',
# 'quant_scheme': 'per_channel_symmetric'
# }]
# For now quant_dtype's options are 'int' and 'uint'. And quant_scheme's options are per_tensor_affine,
# per_tensor_symmetric, per_channel_affine and per_channel_symmetric.
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')
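# Note: set_quant_scheme_dtype(quant_type, new_scheme, new_dtype) overrides the global default
# quantization scheme/dtype for the given quant type ('weight', 'input' or 'output'); layer-wise
# 'quant_dtype'/'quant_scheme' entries in configure_list still take precedence over these defaults.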
model = NaiveModel().to(device)
dummy_input = torch.randn(1, 1, 28, 28).to(device)
@@ -98,5 +110,6 @@ def main():
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
print("Generated calibration config is: ", calibration_config)
if __name__ == '__main__':
main()
import logging
try:
import torch
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
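# e.g. torch.__version__ == '1.10.2+cu113' yields TORCH_VERSION == (1, 10)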
except Exception:
logging.info("PyTorch is not installed.")
TORCH_VERSION = None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
from nni.compression.pytorch.quantization.observers import RecordingObserver
from . import default_layers
_logger = logging.getLogger(__name__)
@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
assert len(inputs) == 1, "Quantization of input only supports ops with single input."
new_inp = self.quantizer.quant_grad(
inputs[0],
QuantType.QUANT_INPUT,
QuantType.INPUT,
self)
inputs = (new_inp,)
@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self.quantizer.quant_grad(
new_weight,
QuantType.QUANT_WEIGHT,
QuantType.WEIGHT,
self, inputs[0])
result = self.module(*inputs)
@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
result,
QuantType.QUANT_OUTPUT,
QuantType.OUTPUT,
self)
return result
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
if isinstance(model, torch.nn.DataParallel):
model = model.module
model_copied = copy.deepcopy(model)
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.all_shapes = {}
self.record_shape(model_copied, dummy_input)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer)
@@ -845,25 +851,54 @@ class Quantizer(Compressor):
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name
def step_with_optimizer(self):
pass
def record_shape(self, model, dummy_input):
"""
Record the input/output shapes of each module to be quantized
class QuantType:
"""
Enum class for quantization type.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
Parameters
----------
model : torch.nn.Module
model to be recorded.
dummy_input : tuple of torch.Tensor
inputs to the model.
"""
def _pre_forward_hook(self, inp):
# Only record the first tensor of the input
return self.pre_forward(inp[0])
def _post_forward_hook(self, _, out):
return self.post_forward(out)
if dummy_input is None:
return
all_handles = []
all_observers = {}
modules_to_compress = self.get_modules_to_compress()
compress_names = [layer_info[0].name for layer_info in modules_to_compress]
for name, module in model.named_modules():
if name in compress_names:
all_observers[name] = {}
all_observers[name]['input_hook'] = RecordingObserver()
all_observers[name]['output_hook'] = RecordingObserver()
module.add_module('pre_forward', all_observers[name]['input_hook'])
module.add_module('post_forward', all_observers[name]['output_hook'])
all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
all_handles.append(module.register_forward_hook(_post_forward_hook))
model(dummy_input)
for name, hooks in all_observers.items():
# only support single input
input_val = hooks['input_hook'].tensor_val
input_shape = input_val[0].shape if input_val else None
output_val = hooks['output_hook'].tensor_val
output_shape = output_val[0].shape if output_val else None
shapes = [input_shape, output_shape]
self.all_shapes[name] = shapes
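# all_shapes now maps each compressed module's name to [input_shape, output_shape];
# an entry is None when the recording hook captured no tensor during the dummy forward pass.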
return
QType_Dict = {
0: "input",
1: "weight",
2: "output"
}
def step_with_optimizer(self):
pass
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
class QuantGrad(torch.autograd.Function):
"""
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
`QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
`QuantType.OUTPUT`, you can define different behavior for different types.
scale : Tensor
scale for quantizing tensor
zero_point : Tensor
zero_point for quantizing tensor
qmin : Tensor
@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)
bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
if hasattr(wrapper.module, "layer_quant_setting"):
layer_quant_setting = wrapper.module.layer_quant_setting
qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
else:
# todo: when dtype/scheme customization is ready for all quantizers, remove this
bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = 0, (1 << bits) - 1
scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
scale = getattr(wrapper.module, scale_name)
zero_point = getattr(wrapper.module, zero_point_name)
# todo: remove this when other quantizers use different scale & zero point for input/weight/output
elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
scale = wrapper.module.scale
zero_point = wrapper.module.zero_point
else:
scale, zero_point = None, None
ctx.save_for_backward(tensor)
# Only tensors that have gradients flowing back need to be saved by save_for_backward.
# Others should be assigned to ctx directly.
ctx.scale = scale
ctx.zero_point = zero_point
ctx.save_for_backward(tensor)
ctx.quant_type = quant_type
ctx.qmin, ctx.qmax = qmin, qmax
ctx.scale = scale
ctx.zero_point = zero_point
return output
@classmethod
def backward(cls, ctx, grad_output):
tensor = ctx.saved_variables[0]
scale, zero_point = ctx.scale, ctx.zero_point
qmin, qmax = ctx.qmin, ctx.qmax
quant_type = ctx.quant_type
qmin, qmax = ctx.qmin, ctx.qmax
output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
return output, None, None, None
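# The three None values are the gradients w.r.t. quant_type, wrapper and input_tensor,
# which do not require gradients; only the quantized tensor receives a gradient.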
@@ -977,11 +1023,11 @@ def _check_bias(module):
return False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
if quant_type == QuantType.INPUT:
output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
elif quant_type == QuantType.WEIGHT:
output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
elif quant_type == QuantType.OUTPUT:
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
from enum import Enum, EnumMeta
class _QuantLiteralEnumMeta(EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
pass
class QuantScheme(str, _QuantLiteralEnum):
PER_TENSOR_AFFINE = 'per_tensor_affine'
PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
PER_CHANNEL_AFFINE = 'per_channel_affine'
PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'
PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]
class QuantDtype(str, _QuantLiteralEnum):
UINT = 'uint'
INT = 'int'
class QuantType(str, _QuantLiteralEnum):
INPUT = 'input'
WEIGHT = 'weight'
OUTPUT = 'output'
def type_to_scale_zero_point_name(self):
if self == QuantType.INPUT:
return 'input_scale', 'input_zero_point'
elif self == QuantType.WEIGHT:
return 'weight_scale', 'weight_zero_point'
elif self == QuantType.OUTPUT:
return 'output_scale', 'output_zero_point'
else:
raise TypeError
# These literals just name the quantization config attributes for reference; they have no practical effect.
class QuantConfigLiteral(str, _QuantLiteralEnum):
QUANT_SETTINGS = 'quant_settings'
QUANT_SCHEME = 'quant_scheme'
QUANT_DTYPE = 'quant_dtype'
BITS = 'bits'
QMIN = 'qmin'
QMAX = 'qmax'
INPUT_SCALE = 'input_scale'
INPUT_ZERO_POINT = 'input_zero_point'
OUTPUT_SCALE = 'output_scale'
OUTPUT_ZERO_POINT = 'output_zero_point'
WEIGHT_SCALE = 'weight_scale'
WEIGHT_ZERO_POINT = 'weight_zero_point'
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
__all__ = ["default_weight_observer", "default_histogram_observer"]
__all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]
class RecordingObserver(_RecordingObserver):
"""
An extended version of PyTorch's RecordingObserver that can record GPU tensors: values are moved to CPU before being recorded and the input is returned unchanged.
"""
def forward(self, x):
val = x.cpu()
super().forward(val)
return x
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length
# default settings for quantization module
quant_default_settings = {
QuantType.WEIGHT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT,
},
QuantType.INPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
},
QuantType.OUTPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
}
}
class TensorQuantSetting(object):
def __init__(self, **kwargs):
self._fields = {}
for k, v in kwargs.items():
self._fields[k] = v
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self._fields[name] = val
def __getattr__(self, name):
if name == "_fields" or name not in self._fields:
raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
return self._fields[name]
def get_qmin_qmax(self):
assert 'qmin' in self._fields and 'qmax' in self._fields, \
"Can not found qmin & qmax in TensorQuantSetting"
return self._fields['qmin'], self._fields['qmax']
class LayerQuantSetting(object):
def __init__(self, config):
self.input: Optional[TensorQuantSetting] = None
self.weight: Optional[TensorQuantSetting] = None
self.output: Optional[TensorQuantSetting] = None
self._extra_layer_setting = {}
for quant_type in QuantType:
if quant_type in config.get("quant_types", []):
setting = TensorQuantSetting()
quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
setting.quant_scheme = quant_scheme
quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
setting.quant_dtype = quant_dtype
bits = get_bits_length(config, quant_type)
qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
setting.bits = bits
setting.qmin = qmin
setting.qmax = qmax
setattr(self, quant_type, setting)
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_") or name in QuantType:
super().__setattr__(name, val)
else:
self._extra_layer_setting[name] = val
def __getattr__(self, name):
if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
return self._extra_layer_setting[name]
@staticmethod
def parse_optional_config(config, quant_type, target):
def get_config(config, quant_type, target):
if not config.get(target):
return None
if isinstance(config[target], dict):
return config[target].get(quant_type)
else:
return config[target]
default_val = quant_default_settings[quant_type].get(target, None)
config_val = get_config(config, quant_type, target)
val = config_val if config_val else default_val
return val
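# Illustrative example: with config = {'quant_types': ['weight'], 'quant_bits': 8,
# 'quant_dtype': 'int', 'quant_scheme': 'per_channel_symmetric'},
# LayerQuantSetting(config).weight.get_qmin_qmax() evaluates to (-127, 127),
# i.e. the symmetric signed 8-bit range [-2**7 + 1, 2**7 - 1].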
def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
# todo: remove this if we convert string config to enum type.
if isinstance(quant_type, str):
assert quant_type in QuantType, "Wrong quant_type"
if isinstance(new_scheme, str):
assert new_scheme in QuantScheme, "Wrong quant_scheme"
if isinstance(new_dtype, str):
assert new_dtype in QuantDtype, "Wrong quant_dtype"
# TODO: It is not a good idea to directly modify global settings. A better choice is
# making this function an attribute function of Quantizer and call this function after
# the quantizer is initialized. However, within current framework of quantization, if
# we want to modify the dtype & scheme when the quantizer is initialized, we must do
# some other things (like changing the shapes of scales and zero_points and other quantization
# information in the subclass).
global quant_default_settings
if new_scheme is not None:
quant_default_settings[quant_type]['quant_scheme'] = new_scheme
if new_dtype is not None:
quant_default_settings[quant_type]['quant_dtype'] = new_dtype
return
import torch
from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType
def calculate_qmin_qmax(bits, dtype):
if dtype == QuantDtype.INT:
qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
elif dtype == QuantDtype.UINT:
qmin, qmax = 0, 2 ** bits - 1
else:
raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
return qmin, qmax
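# e.g. calculate_qmin_qmax(8, QuantDtype.INT) == (-127, 127) and
# calculate_qmin_qmax(8, QuantDtype.UINT) == (0, 255)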
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
def get_target_dim(quant_type, quant_scheme):
# for weight: c_out x c_in x (h) x (w)
# for feature maps: batch x channel x (t) x h x w
# other types are not supported for now
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
target_dim = default_idx
else:
target_dim = None
return target_dim
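# e.g. per-channel weight quantization targets dim 0 (output channels):
# get_target_dim(QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC) == 0,
# per-channel activation quantization targets dim 1, and any per-tensor scheme returns None.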
def get_min_max_value(x, quant_type, quant_scheme):
target_dim = get_target_dim(quant_type, quant_scheme)
if target_dim is None:
return torch.min(x), torch.max(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
del indices[target_dim]
if TORCH_VERSION > (1, 6):
min_val = torch.amin(x, indices, keepdims=True)
max_val = torch.amax(x, indices, keepdims=True)
else:
min_val = max_val = x
for ind in indices:
min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
return min_val, max_val
def get_mean_value(x, target_dim=None):
if target_dim is None:
return torch.mean(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
del indices[target_dim]
mean_val = torch.mean(x, dim=indices, keepdim=True)
return mean_val
def is_per_channel(quant_scheme):
if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
return True
else:
return False
def get_quant_shape(shape, quant_type, quant_scheme):
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
quant_shape = []
return quant_shape
@@ -9,6 +9,7 @@ import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
import math
@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [t[0].name for t in modules_to_compress]
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
def test_quantization_dtype_scheme(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
self.bn1 = torch.nn.BatchNorm2d(2)
def forward(self, x):
x = self.bn1(self.conv1(x))
return x
dtypes = ['int', 'uint']
qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
for dtype in dtypes:
for qscheme in qschemes:
config_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': 8,
'op_types': ['Conv2d'],
'quant_dtype': dtype,
'quant_scheme': qscheme
}]
model = TestModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# only QAT_Quantizer is supported for now
dummy = torch.randn(1, 1, 4, 4)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
# test layer setting
for layer, config in quantizer.modules_to_compress:
module = layer.module
name = layer.name
layer_setting = module.layer_quant_setting
qmin, qmax = calculate_qmin_qmax(8, dtype)
all_quant_types = ['input', 'weight']
for quant_type in all_quant_types:
# check for settings
tensor_setting = getattr(layer_setting, quant_type)
self.assertTrue(tensor_setting is not None)
self.assertTrue(tensor_setting.quant_scheme == qscheme)
self.assertTrue(tensor_setting.quant_dtype == dtype)
self.assertTrue(tensor_setting.qmin == qmin)
self.assertTrue(tensor_setting.qmax == qmax)
input_shape, output_shape = quantizer.all_shapes[name]
shape = input_shape if quant_type == 'input' else module.weight.shape
quant_shape = get_quant_shape(shape, quant_type, qscheme)
scale_name = quant_type + '_scale'
zero_point_name = quant_type + '_zero_point'
scale = getattr(module, scale_name)
zero_point = getattr(module, zero_point_name)
self.assertTrue(list(scale.shape) == quant_shape)
self.assertTrue(list(zero_point.shape) == quant_shape)
weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(18. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(18. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(18. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(18. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
wrapper = getattr(model, name)
wrapper.module.weight = weight
quantizer.quantize_weight(wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(15. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(15. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0.]).view([1, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(15. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(15. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
quantizer.quantize_input(inp, wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress()
# test quantize
@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
weight = torch.tensor([[1, 2], [3, 5]]).float()
model.conv2.module.weight.data = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.weight_zero_point == 0
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.weight = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.weight_zero_point in (42, 43)
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
# test ema
eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
quantizer.step_with_optimizer()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
def test_torch_quantizer_export(self):
config_list_qat = [{
@@ -424,12 +551,15 @@ class CompressorTestCase(TestCase):
}]
config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
dummy = torch.randn(1, 1, 28, 28)
for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
model = TorchModel()
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = quantize_algorithm(model, config, optimizer)
if quantize_algorithm == torch_quantizer.QAT_Quantizer:
quantizer = quantize_algorithm(model, config, optimizer, dummy)
else:
quantizer = quantize_algorithm(model, config, optimizer)
quantizer.compress()
x = torch.rand((1, 1, 28, 28), requires_grad=True)
@@ -461,7 +591,11 @@ class CompressorTestCase(TestCase):
model = TorchModel().eval()
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = quantize_algorithm(model, configure_list, optimizer)
if quantize_algorithm == torch_quantizer.QAT_Quantizer:
dummy = torch.randn(1, 1, 28, 28)
quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
else:
quantizer = quantize_algorithm(model, configure_list, optimizer)
quantizer.compress()
if calibration_config is not None:
quantizer.load_calibration_config(calibration_config)