Unverified Commit c9cd53aa authored by chenbohua3's avatar chenbohua3 Committed by GitHub

support dtype&scheme customization for QAT quantizer (#4137)

parent b0f34da1
......@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
you can define different behavior for different types.
Returns
-------
......@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
"""
# for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
if quant_type == QuantType.QUANT_OUTPUT:
if quant_type == QuantType.OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
......
......@@ -2,11 +2,13 @@ import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
......@@ -68,13 +70,23 @@ def main():
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1'],
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc2'],
'op_names': ['fc1', 'fc2'],
}]
# you can also set the quantization dtype and scheme layer-wise through configure_list like:
# configure_list = [{
# 'quant_types': ['weight', 'input'],
# 'quant_bits': {'weight': 8, 'input': 8},
# 'op_names': ['conv1', 'conv2'],
# 'quant_dtype': 'int',
# 'quant_scheme': 'per_channel_symmetric'
# }]
# For now quant_dtype's options are 'int' and 'uint'. quant_scheme's options are 'per_tensor_affine',
# 'per_tensor_symmetric', 'per_channel_affine' and 'per_channel_symmetric'.
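# The schema also accepts per-type dicts, so different quant types of the same layer can
# use different settings. A hedged illustration (not from the original example; adjust
# op_names to your model):
# configure_list = [{
#     'quant_types': ['weight', 'input'],
#     'quant_bits': {'weight': 8, 'input': 8},
#     'op_names': ['conv1'],
#     'quant_dtype': {'weight': 'int', 'input': 'uint'},
#     'quant_scheme': {'weight': 'per_channel_symmetric', 'input': 'per_tensor_affine'}
# }]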
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')
model = NaiveModel().to(device)
dummy_input = torch.randn(1, 1, 28, 28).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
......@@ -98,5 +110,6 @@ def main():
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
print("Generated calibration config is: ", calibration_config)
if __name__ == '__main__':
main()
......@@ -6,9 +6,21 @@ from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
from .observers import default_weight_observer, default_histogram_observer
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad
from nni.compression.pytorch.quantization.literal import (
PER_CHANNEL_QUANT_SCHEME,
QuantScheme,
QuantDtype,
QuantType
)
from nni.compression.pytorch.quantization.observers import default_weight_observer, default_histogram_observer
from nni.compression.pytorch.quantization.settings import LayerQuantSetting
from nni.compression.pytorch.quantization.utils import (
calculate_qmin_qmax,
get_bits_length,
get_min_max_value,
get_quant_shape
)
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']
......@@ -65,7 +77,7 @@ def update_ema(biased_ema, value, decay):
return biased_ema
def update_quantization_param(bits, rmin, rmax):
def update_quantization_param(bits, rmin, rmax, dtype, scheme):
"""
calculate the `zero_point` and `scale`.
......@@ -77,41 +89,46 @@ def update_quantization_param(bits, rmin, rmax):
min value of real value
rmax : Tensor
max value of real value
dtype : QuantDtype
quantized data type
scheme : QuantScheme
quantization scheme to be used
Returns
-------
float, float
"""
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
qmin = torch.Tensor([0]).to(rmin.device)
qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)
# First determine the scale.
scale = (rmax - rmin) / (qmax - qmin)
# Zero-point computation.
initial_zero_point = qmin - rmin / scale
# Now we need to nudge the zero point to be an integer
if initial_zero_point < qmin:
nudged_zero_point = qmin
elif initial_zero_point > qmax:
nudged_zero_point = qmax
# I think this is for activations that need to be padded during training.
# However, this is the default behavior of the PyTorch quantization observer,
# so we also make it the default behavior here.
rmin = torch.min(rmin, torch.zeros_like(rmin))
rmax = torch.max(rmax, torch.zeros_like(rmax))
zero_point = torch.zeros_like(rmin)
# todo: there is no need to calculate qmin and qmax again
qmin, qmax = calculate_qmin_qmax(bits, dtype)
if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]:
abs_max = torch.max(torch.abs(rmin), torch.abs(rmax))
scale = abs_max / (float(qmax - qmin) / 2)
if dtype == QuantDtype.UINT:
zero_point_val = (qmin + qmax) // 2
zero_point = zero_point.new_full(zero_point.size(), zero_point_val)
else:
nudged_zero_point = torch.round(initial_zero_point)
scale = (rmax - rmin) / float(qmax - qmin)
zero_point = qmin - torch.round(rmin / scale)
return scale, nudged_zero_point
zero_point = torch.clamp(zero_point, qmin, qmax)
# todo: add these lines
# eps = torch.finfo(torch.float32).eps
# scale = torch.max(scale, eps)
return scale, zero_point
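# A quick sanity check of the affine branch above (not part of the quantizer; values assume
# 8-bit uint, i.e. qmin=0, qmax=255, and a hypothetical rmin=-1.0, rmax=3.0):
#     rmin, rmax = torch.tensor(-1.0), torch.tensor(3.0)
#     scale = (rmax - rmin) / float(255 - 0)          # ~0.0157
#     zero_point = 0 - torch.round(rmin / scale)      # 64, so real 0.0 maps exactly to int 64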
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
class QATGrad(QuantGrad):
@staticmethod
......@@ -384,22 +401,49 @@ class QAT_Quantizer(Quantizer):
self.bound_model.register_buffer("steps", torch.tensor(1))
for layer, config in modules_to_compress:
module = layer.module
module.register_buffer("zero_point", torch.tensor([0.0]))
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
name = layer.name
# TODO: may relax this limitation?
assert name in self.all_shapes, "Could not find shapes for layer {}".format(name)
input_shape, output_shape = self.all_shapes[name]
layer_quant_setting = LayerQuantSetting(config)
layer_quant_setting.ema_decay = 0.99
quant_start_step = config.get('quant_start_step', 0)
layer_quant_setting.quant_start_step = quant_start_step
# todo: support other ranks and remove this check
if isinstance(module, torch.nn.Linear):
if "input" in config.get("quant_types", []) and \
layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
if len(input_shape) != 2:
logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs "
"of the layer is 2. Skip quantization of layer %s.", name)
continue
if "output" in config.get("quant_types", []) and \
layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
if len(output_shape) != 2:
logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs "
"of the layer is 2. Skip quantization of layer %s.", name)
continue
if "weight" in config.get("quant_types", []):
weight_bits = get_bits_length(config, 'weight')
layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme)
module.register_buffer('weight_scale', torch.zeros(quant_shape))
module.register_buffer('weight_zero_point', torch.zeros(quant_shape))
if "input" in config.get("quant_types", []):
input_bits = get_bits_length(config, 'input')
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
layer.module.register_buffer('input_bits', torch.Tensor([int(input_bits)]))
quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme)
module.register_buffer('tracked_min_input', torch.zeros(quant_shape))
module.register_buffer('tracked_max_input', torch.zeros(quant_shape))
module.register_buffer('input_scale', torch.zeros(quant_shape))
module.register_buffer('input_zero_point', torch.zeros(quant_shape))
if "output" in config.get("quant_types", []):
output_bits = get_bits_length(config, 'output')
layer.module.register_buffer('output_bits', torch.Tensor([int(output_bits)]))
layer.module.register_buffer('tracked_min_output', torch.zeros(1))
layer.module.register_buffer('tracked_max_output', torch.zeros(1))
quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme)
module.register_buffer('tracked_min_output', torch.zeros(quant_shape))
module.register_buffer('tracked_max_output', torch.zeros(quant_shape))
module.register_buffer('output_scale', torch.zeros(quant_shape))
module.register_buffer('output_zero_point', torch.zeros(quant_shape))
setattr(module, "layer_quant_setting", layer_quant_setting)
self.bound_model.to(device)
def _del_simulated_attr(self, module):
......@@ -407,8 +451,9 @@ class QAT_Quantizer(Quantizer):
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
'output_bits', 'BN_FOLD_TAG', 'input_bits']
'tracked_min_input', 'tracked_max_input', 'BN_FOLD_TAG',
'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point',
'output_scale', 'output_zero_point', 'layer_quant_setting']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
......@@ -422,6 +467,7 @@ class QAT_Quantizer(Quantizer):
config_list : list of dict
List of configurations
"""
SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6']
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
......@@ -429,41 +475,51 @@ class QAT_Quantizer(Quantizer):
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({
Optional('input'): lambda x: x in QuantScheme,
Optional('weight'): lambda x: x in QuantScheme,
Optional('output'): lambda x: x in QuantScheme
})),
Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({
Optional('input'): lambda x: x in QuantDtype,
Optional('weight'): lambda x: x in QuantDtype,
Optional('output'): lambda x: x in QuantDtype
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [str],
Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
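# For reference, a config entry that this schema accepts could look like the following
# (a hedged illustration; op/layer names are placeholders):
#     config_list = [{
#         'quant_types': ['weight', 'output'],
#         'quant_bits': {'weight': 8, 'output': 8},
#         'quant_scheme': {'weight': 'per_channel_symmetric', 'output': 'per_tensor_affine'},
#         'quant_dtype': 'int',
#         'quant_start_step': 100,
#         'op_types': ['Conv2d']    # must be one of SUPPORTED_OPS
#     }]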
def _quantize(self, bits, op, real_val):
def _quantize(self, real_value, scale, zero_point, qmin, qmax):
"""
quantize real value.
Parameters
----------
bits : int
quantization bits length
op : torch.nn.Module
target module
real_val : Tensor
real value to be quantized
real_value : torch.Tensor
the real value to be quantized
scale : torch.Tensor
quantization scale
zero_point : torch.Tensor
quantization zero point
qmin : int
lower bound of the int range
qmax : int
upper bound of the int range
Returns
-------
Tensor
"""
op.zero_point = op.zero_point.to(real_val.device)
op.scale = op.scale.to(real_val.device)
transformed_val = op.zero_point + real_val / op.scale
qmin = 0
qmax = (1 << bits) - 1
transformed_val = zero_point + real_value / scale
clamped_val = torch.clamp(transformed_val, qmin, qmax)
quantized_val = torch.round(clamped_val)
return quantized_val
def _dequantize(self, op, quantized_val):
def _dequantize(self, quantized_val, scale, zero_point):
"""
dequantize quantized value.
Because we simulate quantization in the training process, all the computations still happen as floating point computations, which means we
......@@ -471,103 +527,149 @@ class QAT_Quantizer(Quantizer):
Parameters
----------
op : torch.nn.Module
target module
quantized_val : float
quantized_val value to be dequantized
quantized_val : torch.Tensor
the quantized value to be de-quantized
scale : torch.Tensor
quantization scale
zero_point : torch.Tensor
quantization zero point
Returns
-------
float
Tensor
"""
real_val = op.scale * (quantized_val - op.zero_point)
real_val = scale * (quantized_val - zero_point)
return real_val
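# Together, _quantize and _dequantize implement standard fake quantization. A rough sketch
# with hypothetical values (scale=0.1, zero_point=0, int range [-127, 127]):
#     x = torch.tensor([0.05, 1.0, 20.0])
#     q = torch.round(torch.clamp(0 + x / 0.1, -127, 127))   # -> [0., 10., 127.]
#     x_hat = 0.1 * (q - 0)                                   # -> approximately [0.0, 1.0, 12.7]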
def quantize_weight(self, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight = module.weight
weight_bits = int(module.weight_bits)
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.weight
if quant_start_step > int(self.bound_model.steps):
return weight
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
# In evaluation mode, we only quantize weight without updating statistics
if not wrapper.training:
scale, zero_point = module.weight_scale, module.weight_zero_point
weight = self._quantize(weight, scale, zero_point, qmin, qmax)
weight = self._dequantize(weight, scale, zero_point)
module.weight = weight
return weight
if quant_start_step > int(self.bound_model.steps):
return weight
# quantize weight
rmin, rmax = torch.min(weight), torch.max(weight)
scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme)
scale, zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme)
module.weight_scale.copy_(scale)
module.weight_zero_point.copy_(zero_point)
weight = self._quantize(weight, scale, zero_point, qmin, qmax)
weight = self._dequantize(weight, scale, zero_point)
# Weight can not be modified in place, so when torch.nn.DataParallel is used, this update
# will be lost after each forward pass. However, the update takes effect on each
# replicated module during each forward pass, which makes the quantized weight
# be used correctly.
wrapper.module.weight = weight
return weight
def quantize_input(self, inputs, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input_bits = int(module.input_bits)
quant_start_step = config.get('quant_start_step', 0)
assert input_bits >= 1, "quant bits length should be at least 1"
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.input
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
ema_decay = layer_quant_setting.ema_decay
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
if not wrapper.training:
scale = module.input_scale
zero_point = module.input_zero_point
inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
inputs = self._dequantize(inputs, scale, zero_point)
return inputs
# we dont update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(inputs), torch.max(inputs)
current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme)
if int(self.bound_model.steps) == 1:
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)
tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay)
module.tracked_min_input.copy_(tracked_min_input)
module.tracked_max_input.copy_(tracked_max_input)
if quant_start_step > int(self.bound_model.steps):
return inputs
scale, zero_point = update_quantization_param(
input_bits, module.tracked_min_input, module.tracked_max_input)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme)
module.input_scale.copy_(scale)
module.input_zero_point.copy_(zero_point)
inp = self._quantize(input_bits, module, inputs)
inp = self._dequantize(module, inp)
return inp
inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
inputs = self._dequantize(inputs, scale, zero_point)
return inputs
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = int(module.output_bits)
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.output
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(output), torch.max(output)
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
ema_decay = layer_quant_setting.ema_decay
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
if not wrapper.training:
scale = module.output_scale
zero_point = module.output_zero_point
output = self._quantize(output, scale, zero_point, qmin, qmax)
output = self._dequantize(output, scale, zero_point)
return output
current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme)
if int(self.bound_model.steps) == 1:
module.tracked_min_output.copy_(current_min)
module.tracked_max_output.copy_(current_max)
return output
# we dont update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(output), torch.max(output)
tracked_min_output = update_ema(module.tracked_min_output, current_min,
module.ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay)
module.tracked_min_output.copy_(tracked_min_output)
module.tracked_max_output.copy_(tracked_max_output)
if quant_start_step > int(self.bound_model.steps):
return output
scale, zero_point = update_quantization_param(
output_bits, module.tracked_min_output, module.tracked_max_output)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme)
module.output_scale.copy_(scale)
module.output_zero_point.copy_(zero_point)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
output = self._quantize(output, scale, zero_point, qmin, qmax)
output = self._dequantize(output, scale, zero_point)
return output
def load_calibration_config(self, calibration_config):
modules_to_compress = self.get_modules_to_compress()
......@@ -581,12 +683,12 @@ class QAT_Quantizer(Quantizer):
assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
if hasattr(module, 'input_bits'):
assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
module.tracked_min_input.data = torch.Tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.Tensor([calibration_config[name]['tracked_max_input']])
module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
if hasattr(module, 'output_bits'):
assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
module.tracked_min_output.data = torch.Tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.Tensor([calibration_config[name]['tracked_max_output']])
module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
......@@ -619,6 +721,8 @@ class QAT_Quantizer(Quantizer):
calibration_config[name] = {}
if hasattr(module, 'weight_bits'):
calibration_config[name]['weight_bits'] = int(module.weight_bits)
calibration_config[name]['weight_scale'] = module.weight_scale
calibration_config[name]['weight_zero_point'] = module.weight_zero_point
# Recover weight/bias for batch normalization folding
actual_weight = getattr(module, 'old_weight', None)
......@@ -759,7 +863,7 @@ class DoReFaQuantizer(Quantizer):
class ClipGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
if quant_type == QuantType.QUANT_OUTPUT:
if quant_type == QuantType.OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0
return grad_output
......
import logging
try:
import torch
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
logging.info("PyTorch is not installed.")
TORCH_VERSION = None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
from nni.compression.pytorch.quantization.observers import RecordingObserver
from . import default_layers
_logger = logging.getLogger(__name__)
......@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
assert len(inputs) == 1, "Quantization of input only supports ops with single input."
new_inp = self.quantizer.quant_grad(
inputs[0],
QuantType.QUANT_INPUT,
QuantType.INPUT,
self)
inputs = (new_inp,)
......@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self.quantizer.quant_grad(
new_weight,
QuantType.QUANT_WEIGHT,
QuantType.WEIGHT,
self, inputs[0])
result = self.module(*inputs)
......@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
result,
QuantType.QUANT_OUTPUT,
QuantType.OUTPUT,
self)
return result
......@@ -604,10 +607,13 @@ class Quantizer(Compressor):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
if isinstance(model, torch.nn.DataParallel):
model = model.module
model_copied = copy.deepcopy(model)
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.all_shapes = {}
self.record_shape(model_copied, dummy_input)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer)
......@@ -845,25 +851,54 @@ class Quantizer(Compressor):
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name
def step_with_optimizer(self):
pass
class QuantType:
def record_shape(self, model, dummy_input):
"""
Enum class for quantization type.
Record the input/output shapes of each module to be quantized
Parameters
----------
model : torch.nn.Module
model to be recorded.
dummy_input : tuple of torch.Tensor
inputs to the model.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
def _pre_forward_hook(self, inp):
# Only record the first tensor of the input
return self.pre_forward(inp[0])
def _post_forward_hook(self, _, out):
return self.post_forward(out)
QType_Dict = {
0: "input",
1: "weight",
2: "output"
}
if dummy_input is None:
return
all_handles = []
all_observers = {}
modules_to_compress = self.get_modules_to_compress()
compress_names = [layer_info[0].name for layer_info in modules_to_compress]
for name, module in model.named_modules():
if name in compress_names:
all_observers[name] = {}
all_observers[name]['input_hook'] = RecordingObserver()
all_observers[name]['output_hook'] = RecordingObserver()
module.add_module('pre_forward', all_observers[name]['input_hook'])
module.add_module('post_forward', all_observers[name]['output_hook'])
all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
all_handles.append(module.register_forward_hook(_post_forward_hook))
model(dummy_input)
for name, hooks in all_observers.items():
# only support single input
input_val = hooks['input_hook'].tensor_val
input_shape = input_val[0].shape if input_val else None
output_val = hooks['output_hook'].tensor_val
output_shape = output_val[0].shape if output_val else None
shapes = [input_shape, output_shape]
self.all_shapes[name] = shapes
return
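# With a dummy input supplied to the quantizer, the recorded shapes are later looked up by
# layer name, e.g. (a hedged usage sketch; the layer name depends on your model):
#     quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=torch.randn(1, 1, 28, 28))
#     input_shape, output_shape = quantizer.all_shapes['conv1']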
def step_with_optimizer(self):
pass
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
class QuantGrad(torch.autograd.Function):
"""
......@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
grad_output : Tensor
gradient of the output of quantization operation
scale : Tensor
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
`QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
`QuantType.OUTPUT`, you can define different behavior for different types.
zero_point : Tensor
zero_point for quantizing tensor
qmin : Tensor
......@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)
bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
if hasattr(wrapper.module, "layer_quant_setting"):
layer_quant_setting = wrapper.module.layer_quant_setting
qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
else:
# todo: when dtype/scheme customization is ready for all quantizers, remove this
bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = 0, (1 << bits) - 1
scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
scale = getattr(wrapper.module, scale_name)
zero_point = getattr(wrapper.module, zero_point_name)
# todo: remove this when other quantizers use different scale & zero point for input/weight/output
elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
scale = wrapper.module.scale
zero_point = wrapper.module.zero_point
else:
scale, zero_point = None, None
ctx.save_for_backward(tensor)
# Only tensors that have gradients flowing back need to be saved by save_for_backward.
# Other values can be assigned to ctx directly.
ctx.scale = scale
ctx.zero_point = zero_point
ctx.save_for_backward(tensor)
ctx.quant_type = quant_type
ctx.qmin, ctx.qmax = qmin, qmax
ctx.scale = scale
ctx.zero_point = zero_point
return output
@classmethod
def backward(cls, ctx, grad_output):
tensor = ctx.saved_variables[0]
scale, zero_point = ctx.scale, ctx.zero_point
qmin, qmax = ctx.qmin, ctx.qmax
quant_type = ctx.quant_type
qmin, qmax = ctx.qmin, ctx.qmax
output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
return output, None, None, None
......@@ -977,11 +1023,11 @@ def _check_bias(module):
return False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
if quant_type == QuantType.INPUT:
output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
elif quant_type == QuantType.WEIGHT:
output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
elif quant_type == QuantType.OUTPUT:
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
......
from enum import Enum, EnumMeta
class _QuantLiteralEnumMeta(EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
pass
class QuantScheme(str, _QuantLiteralEnum):
PER_TENSOR_AFFINE = 'per_tensor_affine'
PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
PER_CHANNEL_AFFINE = 'per_channel_affine'
PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'
PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]
class QuantDtype(str, _QuantLiteralEnum):
UINT = 'uint'
INT = 'int'
class QuantType(str, _QuantLiteralEnum):
INPUT = 'input'
WEIGHT = 'weight'
OUTPUT = 'output'
def type_to_scale_zero_point_name(self):
if self == QuantType.INPUT:
return 'input_scale', 'input_zero_point'
elif self == QuantType.WEIGHT:
return 'weight_scale', 'weight_zero_point'
elif self == QuantType.OUTPUT:
return 'output_scale', 'output_zero_point'
else:
raise TypeError
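# Since these enums subclass str and use _QuantLiteralEnumMeta, plain strings from user
# configs can be membership-checked and compared directly, e.g.:
#     'per_channel_symmetric' in QuantScheme                 # True
#     QuantDtype('uint') == 'uint'                           # True
#     QuantType.WEIGHT.type_to_scale_zero_point_name()       # ('weight_scale', 'weight_zero_point')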
# Just lists each attribute's name; it has no practical effect
class QuantConfigLiteral(str, _QuantLiteralEnum):
QUANT_SETTINGS = 'quant_settings'
QUANT_SCHEME = 'quant_scheme'
QUANT_DTYPE = 'quant_dtype'
BITS = 'bits'
QMIN = 'qmin'
QMAX = 'qmax'
INPUT_SCALE = 'input_scale'
INPUT_ZERO_POINT = 'input_zero_point'
OUTPUT_SCALE = 'output_scale'
OUTPUT_ZERO_POINT = 'output_zero_point'
WEIGHT_SCALE = 'weight_scale'
WEIGHT_ZERO_POINT = 'weight_zero_point'
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
__all__ = ["default_weight_observer", "default_histogram_observer"]
__all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]
class RecordingObserver(_RecordingObserver):
"""
An extended version of PyTorch's RecordingObserver, used to record GPU tensors
"""
def forward(self, x):
val = x.cpu()
super().forward(val)
return x
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length
# default settings for quantization module
quant_default_settings = {
QuantType.WEIGHT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT,
},
QuantType.INPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
},
QuantType.OUTPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
}
}
class TensorQuantSetting(object):
def __init__(self, **kwargs):
self._fields = {}
for k, v in kwargs.items():
self._fields[k] = v
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self._fields[name] = val
def __getattr__(self, name):
if name == "_fields" or name not in self._fields:
raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
return self._fields[name]
def get_qmin_qmax(self):
assert 'qmin' in self._fields and 'qmax' in self._fields, \
"Can not found qmin & qmax in TensorQuantSetting"
return self._fields['qmin'], self._fields['qmax']
class LayerQuantSetting(object):
def __init__(self, config):
self.input: Optional[TensorQuantSetting] = None
self.weight: Optional[TensorQuantSetting] = None
self.output: Optional[TensorQuantSetting] = None
self._extra_layer_setting = {}
for quant_type in QuantType:
if quant_type in config.get("quant_types", []):
setting = TensorQuantSetting()
quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
setting.quant_scheme = quant_scheme
quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
setting.quant_dtype = quant_dtype
bits = get_bits_length(config, quant_type)
qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
setting.bits = bits
setting.qmin = qmin
setting.qmax = qmax
setattr(self, quant_type, setting)
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_") or name in QuantType:
super().__setattr__(name, val)
else:
self._extra_layer_setting[name] = val
def __getattr__(self, name):
if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
return self._extra_layer_setting[name]
@staticmethod
def parse_optional_config(config, quant_type, target):
def get_config(config, quant_type, target):
if not config.get(target):
return None
if isinstance(config[target], dict):
return config[target].get(quant_type)
else:
return config[target]
default_val = quant_default_settings[quant_type].get(target, None)
config_val = get_config(config, quant_type, target)
val = config_val if config_val else default_val
return val
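# A minimal sketch of how a layer setting is resolved from one config entry; fields that
# are absent fall back to quant_default_settings:
#     config = {'quant_types': ['weight'], 'quant_bits': {'weight': 8},
#               'quant_scheme': 'per_channel_symmetric', 'quant_dtype': 'int'}
#     setting = LayerQuantSetting(config)
#     setting.weight.quant_scheme      # 'per_channel_symmetric'
#     setting.weight.get_qmin_qmax()   # (-127, 127)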
def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
# todo: remove this if we convert string config to enum type.
if isinstance(quant_type, str):
assert quant_type in QuantType, "Wrong quant_type"
if isinstance(new_scheme, str):
assert new_scheme in QuantScheme, "Wrong quant_scheme"
if isinstance(new_dtype, str):
assert new_dtype in QuantDtype, "Wrong quant_dtype"
# TODO: It is not a good idea to directly modify global settings. A better choice would be
# to make this function a method of Quantizer and call it after the quantizer is initialized.
# However, within the current quantization framework, modifying the dtype & scheme after the
# quantizer is initialized would also require other changes (such as changing the shapes of
# scales, zero points and other quantization information in the subclass).
global quant_default_settings
if new_scheme is not None:
quant_default_settings[quant_type]['quant_scheme'] = new_scheme
if new_dtype is not None:
quant_default_settings[quant_type]['quant_dtype'] = new_dtype
return
import torch
from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType
def calculate_qmin_qmax(bits, dtype):
if dtype == QuantDtype.INT:
qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
elif dtype == QuantDtype.UINT:
qmin, qmax = 0, 2 ** bits - 1
else:
raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
return qmin, qmax
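# For example, 8 bits gives a symmetric signed range or the usual unsigned range:
#     calculate_qmin_qmax(8, QuantDtype.INT)    # (-127, 127)
#     calculate_qmin_qmax(8, QuantDtype.UINT)   # (0, 255)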
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
def get_target_dim(quant_type, quant_scheme):
# for weight: c_out x c_in x (h) x (w)
# for feature maps: batch x channel x (t) x h x w
# other types are not supported for now
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
target_dim = default_idx
else:
target_dim = None
return target_dim
def get_min_max_value(x, quant_type, quant_scheme):
target_dim = get_target_dim(quant_type, quant_scheme)
if target_dim is None:
return torch.min(x), torch.max(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dimensions of the tensor"
del indices[target_dim]
if TORCH_VERSION > (1, 6):
min_val = torch.amin(x, indices, keepdims=True)
max_val = torch.amax(x, indices, keepdims=True)
else:
min_val = max_val = x
for ind in indices:
min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
return min_val, max_val
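# For a per-channel scheme on a weight tensor (c_out x c_in x h x w), min/max are reduced
# over every dim except dim 0. A small check:
#     w = torch.arange(18.).view(2, 1, 3, 3)
#     get_min_max_value(w, QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC)
#     # -> two tensors of shape (2, 1, 1, 1): mins [0., 9.], maxs [8., 17.]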
def get_mean_value(x, target_dim=None):
if target_dim is None:
return torch.mean(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dimensions of the tensor"
del indices[target_dim]
mean_val = torch.mean(x, dim=indices, keepdim=True)
return mean_val
def is_per_channel(quant_scheme):
if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
return True
else:
return False
def get_quant_shape(shape, quant_type, quant_scheme):
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
quant_shape = []
return quant_shape
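# The quant shape determines the shape of the scale/zero_point buffers registered on the
# module, e.g. for a Conv2d weight of shape (2, 1, 3, 3):
#     get_quant_shape((2, 1, 3, 3), QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC)   # [2, 1, 1, 1]
#     get_quant_shape((2, 1, 3, 3), QuantType.WEIGHT, QuantScheme.PER_TENSOR_AFFINE)       # []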
......@@ -9,6 +9,7 @@ import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
import math
......@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [t[0].name for t in modules_to_compress]
......@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
def test_quantization_dtype_scheme(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
self.bn1 = torch.nn.BatchNorm2d(2)
def forward(self, x):
x = self.bn1(self.conv1(x))
return x
dtypes = ['int', 'uint']
qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
for dtype in dtypes:
for qscheme in qschemes:
config_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': 8,
'op_types': ['Conv2d'],
'quant_dtype': dtype,
'quant_scheme': qscheme
}]
model = TestModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# only QAT_Quantizer is supported for now
dummy = torch.randn(1, 1, 4, 4)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
# test layer setting
for layer, config in quantizer.modules_to_compress:
module = layer.module
name = layer.name
layer_setting = module.layer_quant_setting
qmin, qmax = calculate_qmin_qmax(8, dtype)
all_quant_types = ['input', 'weight']
for quant_type in all_quant_types:
# check for settings
tensor_setting = getattr(layer_setting, quant_type)
self.assertTrue(tensor_setting is not None)
self.assertTrue(tensor_setting.quant_scheme == qscheme)
self.assertTrue(tensor_setting.quant_dtype == dtype)
self.assertTrue(tensor_setting.qmin == qmin)
self.assertTrue(tensor_setting.qmax == qmax)
input_shape, output_shape = quantizer.all_shapes[name]
shape = input_shape if quant_type == 'input' else module.weight.shape
quant_shape = get_quant_shape(shape, quant_type, qscheme)
scale_name = quant_type + '_scale'
zero_point_name = quant_type + '_zero_point'
scale = getattr(module, scale_name)
zero_point = getattr(module, zero_point_name)
self.assertTrue(list(scale.shape) == quant_shape)
self.assertTrue(list(zero_point.shape) == quant_shape)
weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(18. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(18. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(18. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(18. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
wrapper = getattr(model, name)
wrapper.module.weight = weight
quantizer.quantize_weight(wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(15. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(15. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0.]).view([1, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(15. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(15. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
quantizer.quantize_input(inp, wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
......@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress()
# test quantize
......@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
weight = torch.tensor([[1, 2], [3, 5]]).float()
model.conv2.module.weight.data = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.weight_zero_point == 0
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.weight = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.weight_zero_point in (42, 43)
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
......@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
# test ema
eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
quantizer.step_with_optimizer()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
def test_torch_quantizer_export(self):
config_list_qat = [{
......@@ -424,11 +551,14 @@ class CompressorTestCase(TestCase):
}]
config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
dummy = torch.randn(1, 1, 28, 28)
for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
model = TorchModel()
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
if quantize_algorithm == torch_quantizer.QAT_Quantizer:
quantizer = quantize_algorithm(model, config, optimizer, dummy)
else:
quantizer = quantize_algorithm(model, config, optimizer)
quantizer.compress()
......@@ -461,6 +591,10 @@ class CompressorTestCase(TestCase):
model = TorchModel().eval()
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
if quantize_algorithm == torch_quantizer.QAT_Quantizer:
dummy = torch.randn(1, 1, 28, 28)
quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
else:
quantizer = quantize_algorithm(model, configure_list, optimizer)
quantizer.compress()
if calibration_config is not None:
......