"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "45dddfa7a5a456e97177d7ae75a1f5e1b6318c3e"
Unverified commit c9cd53aa, authored by chenbohua3, committed by GitHub

support dtype&scheme customization for QAT quantizer (#4137)

parent b0f34da1
@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backward
    grad_output : Tensor
        gradient of the output of quantization operation
    quant_type : QuantType
        the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
        you can define different behavior for different types.
    Returns
    -------
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backward
    """
    # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
    if quant_type == QuantType.OUTPUT:
        grad_output[torch.abs(tensor) > 1] = 0
    return grad_output
......
@@ -2,11 +2,13 @@ import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

import sys
sys.path.append('../models')
from mnist.naive import NaiveModel


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
@@ -68,13 +70,23 @@ def main():
    }, {
        'quant_types': ['output', 'weight', 'input'],
        'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
        'op_names': ['fc1', 'fc2'],
    }]

    # you can also set the quantization dtype and scheme layer-wise through configure_list like:
    # configure_list = [{
    #     'quant_types': ['weight', 'input'],
    #     'quant_bits': {'weight': 8, 'input': 8},
    #     'op_names': ['conv1', 'conv2'],
    #     'quant_dtype': 'int',
    #     'quant_scheme': 'per_channel_symmetric'
    # }]
    # For now quant_dtype's options are 'int' and 'uint'. And quant_scheme's options are per_tensor_affine,
    # per_tensor_symmetric, per_channel_affine and per_channel_symmetric.
    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
    set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
    set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')

    model = NaiveModel().to(device)
    dummy_input = torch.randn(1, 1, 28, 28).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -98,5 +110,6 @@ def main():
    calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
    print("Generated calibration config is: ", calibration_config)


if __name__ == '__main__':
    main()
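The quantizer construction itself sits in a collapsed part of this script. Judging from the updated unit tests later in this commit, the QAT quantizer is now also given the dummy input so that it can record per-layer shapes; a minimal, hypothetical sketch reusing the names defined above in main():

# Illustration only, not part of the example file:
#     quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
#     quantizer.compress()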
@@ -6,9 +6,21 @@ from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad
from nni.compression.pytorch.quantization.literal import (
    PER_CHANNEL_QUANT_SCHEME,
    QuantScheme,
    QuantDtype,
    QuantType
)
from nni.compression.pytorch.quantization.observers import default_weight_observer, default_histogram_observer
from nni.compression.pytorch.quantization.settings import LayerQuantSetting
from nni.compression.pytorch.quantization.utils import (
    calculate_qmin_qmax,
    get_bits_length,
    get_min_max_value,
    get_quant_shape
)

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']
@@ -65,7 +77,7 @@ def update_ema(biased_ema, value, decay):
    return biased_ema
def update_quantization_param(bits, rmin, rmax, dtype, scheme):
    """
    calculate the `zero_point` and `scale`.
@@ -77,41 +89,46 @@ def update_quantization_param(bits, rmin, rmax):
        min value of real value
    rmax : Tensor
        max value of real value
    dtype : QuantDtype
        quantized data type
    scheme : QuantScheme
        quantization scheme to be used

    Returns
    -------
    float, float
    """
    # extend the [min, max] interval to ensure that it contains 0.
    # Otherwise, we would not meet the requirement that 0 be an exactly
    # representable value.
    # I think this is for activations that need to be pad in the training.
    # However this is a default behavior in PyTorch quantization observer.
    # So we also make it a default behavior
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    zero_point = torch.zeros_like(rmin)

    # todo: there is no need to calculate qmin and qmax again
    qmin, qmax = calculate_qmin_qmax(bits, dtype)

    if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]:
        abs_max = torch.max(torch.abs(rmin), torch.abs(rmax))
        scale = abs_max / (float(qmax - qmin) / 2)
        if dtype == QuantDtype.UINT:
            zero_point_val = (qmin + qmax) // 2
            zero_point = zero_point.new_full(zero_point.size(), zero_point_val)
    else:
        scale = (rmax - rmin) / float(qmax - qmin)
        zero_point = qmin - torch.round(rmin / scale)

    zero_point = torch.clamp(zero_point, qmin, qmax)

    # todo: add these lines
    # eps = torch.finfo(torch.float32).eps
    # scale = torch.max(scale, eps)

    return scale, zero_point
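A short worked example of the two branches above (illustration only, values chosen arbitrarily), using the signed 8-bit range (-127, 127) returned by calculate_qmin_qmax(8, QuantDtype.INT):

import torch

rmin, rmax = torch.tensor(-0.5), torch.tensor(2.0)
qmin, qmax = -127, 127

# symmetric: scale covers the largest magnitude, zero_point stays at 0 for 'int'
abs_max = torch.max(rmin.abs(), rmax.abs())              # 2.0
scale_sym = abs_max / (float(qmax - qmin) / 2)           # 2.0 / 127
zero_point_sym = torch.tensor(0.)

# affine: scale spans the whole [rmin, rmax] interval, zero_point shifts it
scale_aff = (rmax - rmin) / float(qmax - qmin)           # 2.5 / 254
zero_point_aff = qmin - torch.round(rmin / scale_aff)    # -127 - round(-50.8) = -76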
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
class QATGrad(QuantGrad):
    @staticmethod
@@ -384,22 +401,49 @@ class QAT_Quantizer(Quantizer):
        self.bound_model.register_buffer("steps", torch.tensor(1))
        for layer, config in modules_to_compress:
            module = layer.module
            name = layer.name
            # TODO: may relax this limitation?
            assert name in self.all_shapes, "Could not found shapes for layer {}".format(name)
            input_shape, output_shape = self.all_shapes[name]
            layer_quant_setting = LayerQuantSetting(config)
            layer_quant_setting.ema_decay = 0.99
            quant_start_step = config.get('quant_start_step', 0)
            layer_quant_setting.quant_start_step = quant_start_step
            # todo: support other ranks and remove this check
            if isinstance(module, torch.nn.Linear):
                if "input" in config.get("quant_types", []) and \
                        layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
                    if len(input_shape) != 2:
                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs "
                                       "of the layer is 2. Skip quantization of layer %s.", name)
                        continue
                if "output" in config.get("quant_types", []) and \
                        layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
                    if len(output_shape) != 2:
                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs "
                                       "of the layer is 2. Skip quantization of layer %s.", name)
                        continue
            if "weight" in config.get("quant_types", []):
                quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme)
                module.register_buffer('weight_scale', torch.zeros(quant_shape))
                module.register_buffer('weight_zero_point', torch.zeros(quant_shape))
            if "input" in config.get("quant_types", []):
                quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme)
                module.register_buffer('tracked_min_input', torch.zeros(quant_shape))
                module.register_buffer('tracked_max_input', torch.zeros(quant_shape))
                module.register_buffer('input_scale', torch.zeros(quant_shape))
                module.register_buffer('input_zero_point', torch.zeros(quant_shape))
            if "output" in config.get("quant_types", []):
                quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme)
                module.register_buffer('tracked_min_output', torch.zeros(quant_shape))
                module.register_buffer('tracked_max_output', torch.zeros(quant_shape))
                module.register_buffer('output_scale', torch.zeros(quant_shape))
                module.register_buffer('output_zero_point', torch.zeros(quant_shape))
            setattr(module, "layer_quant_setting", layer_quant_setting)
        self.bound_model.to(device)
    def _del_simulated_attr(self, module):
@@ -407,8 +451,9 @@ class QAT_Quantizer(Quantizer):
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
                         'tracked_min_input', 'tracked_max_input', 'BN_FOLD_TAG',
                         'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point',
                         'output_scale', 'output_zero_point', 'layer_quant_setting']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
@@ -422,6 +467,7 @@ class QAT_Quantizer(Quantizer):
        config_list : list of dict
            List of configurations
        """
        SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6']
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
@@ -429,41 +475,51 @@ class QAT_Quantizer(Quantizer):
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({
                Optional('input'): lambda x: x in QuantScheme,
                Optional('weight'): lambda x: x in QuantScheme,
                Optional('output'): lambda x: x in QuantScheme
            })),
            Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({
                Optional('input'): lambda x: x in QuantDtype,
                Optional('weight'): lambda x: x in QuantDtype,
                Optional('output'): lambda x: x in QuantDtype
            })),
            Optional('quant_start_step'): And(int, lambda n: n >= 0),
            Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)
    def _quantize(self, real_value, scale, zero_point, qmin, qmax):
        """
        quantize real value.

        Parameters
        ----------
        real_value : torch.Tensor
            the real value to be quantized
        scale : torch.Tensor
            quantization scale
        zero_point : torch.Tensor
            quantization zero point
        qmin : int
            lower bound of the int range
        qmax : int
            upper bound of the int range

        Returns
        -------
        Tensor
        """
        transformed_val = zero_point + real_value / scale
        clamped_val = torch.clamp(transformed_val, qmin, qmax)
        quantized_val = torch.round(clamped_val)
        return quantized_val

    def _dequantize(self, quantized_val, scale, zero_point):
        """
        dequantize quantized value.
        Because we simulate quantization in training process, all the computations still happen as float point computations, which means we
@@ -471,103 +527,149 @@ class QAT_Quantizer(Quantizer):
        Parameters
        ----------
        quantized_val : torch.Tensor
            the quantized value to be de-quantized
        scale : torch.Tensor
            quantization scale
        zero_point : torch.Tensor
            quantization zero point

        Returns
        -------
        Tensor
        """
        real_val = scale * (quantized_val - zero_point)
        return real_val
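Round-trip sketch (illustration only, not part of the diff): _quantize followed by _dequantize is the usual fake-quantization step, where values are snapped to the grid defined by scale and zero_point but stay in floating point:

import torch

real = torch.tensor([-0.304, 0.0, 0.706, 2.0])
scale, zero_point, qmin, qmax = 0.01, 0., -127, 127

q = torch.round(torch.clamp(zero_point + real / scale, qmin, qmax))   # [-30., 0., 71., 127.]
deq = scale * (q - zero_point)                                        # [-0.30, 0.00, 0.71, 1.27]
# in-range values move by at most scale / 2, while 2.0 saturates at qmax * scale = 1.27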
    def quantize_weight(self, wrapper, **kwargs):
        module = wrapper.module
        weight = module.weight

        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.weight

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step
        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        # In evaluation mode, we only quantize weight without updating statistics
        if not wrapper.training:
            scale, zero_point = module.weight_scale, module.weight_zero_point
            weight = self._quantize(weight, scale, zero_point, qmin, qmax)
            weight = self._dequantize(weight, scale, zero_point)
            module.weight = weight
            return weight

        if quant_start_step > int(self.bound_model.steps):
            return weight

        current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme)
        scale, zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme)
        module.weight_scale.copy_(scale)
        module.weight_zero_point.copy_(zero_point)
        weight = self._quantize(weight, scale, zero_point, qmin, qmax)
        weight = self._dequantize(weight, scale, zero_point)
        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
        # will be lost after each forward process. However, this update takes effect on each
        # replicated module during each forward process, which will make the quantized weight
        # be used correctly.
        wrapper.module.weight = weight
        return weight
    def quantize_input(self, inputs, wrapper, **kwargs):
        module = wrapper.module

        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.input

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step
        ema_decay = layer_quant_setting.ema_decay
        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        if not wrapper.training:
            scale = module.input_scale
            zero_point = module.input_zero_point
            inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
            inputs = self._dequantize(inputs, scale, zero_point)
            return inputs

        current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme)

        if int(self.bound_model.steps) == 1:
            module.tracked_min_input.copy_(current_min)
            module.tracked_max_input.copy_(current_max)

        tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
        tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay)
        module.tracked_min_input.copy_(tracked_min_input)
        module.tracked_max_input.copy_(tracked_max_input)

        if quant_start_step > int(self.bound_model.steps):
            return inputs

        scale, zero_point = update_quantization_param(
            bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme)
        module.input_scale.copy_(scale)
        module.input_zero_point.copy_(zero_point)
        inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
        inputs = self._dequantize(inputs, scale, zero_point)
        return inputs
    def quantize_output(self, output, wrapper, **kwargs):
        module = wrapper.module

        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.output

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step
        ema_decay = layer_quant_setting.ema_decay
        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        if not wrapper.training:
            scale = module.output_scale
            zero_point = module.output_zero_point
            output = self._quantize(output, scale, zero_point, qmin, qmax)
            output = self._dequantize(output, scale, zero_point)
            return output

        current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme)

        if int(self.bound_model.steps) == 1:
            module.tracked_min_output.copy_(current_min)
            module.tracked_max_output.copy_(current_max)

        tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay)
        tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay)
        module.tracked_min_output.copy_(tracked_min_output)
        module.tracked_max_output.copy_(tracked_max_output)

        if quant_start_step > int(self.bound_model.steps):
            return output

        scale, zero_point = update_quantization_param(
            bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme)
        module.output_scale.copy_(scale)
        module.output_zero_point.copy_(zero_point)
        output = self._quantize(output, scale, zero_point, qmin, qmax)
        output = self._dequantize(output, scale, zero_point)
        return output
    def load_calibration_config(self, calibration_config):
        modules_to_compress = self.get_modules_to_compress()
@@ -581,12 +683,12 @@ class QAT_Quantizer(Quantizer):
            assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
            if hasattr(module, 'input_bits'):
                assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
                module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
                module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
            if hasattr(module, 'output_bits'):
                assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
                module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
                module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
@@ -619,6 +721,8 @@ class QAT_Quantizer(Quantizer):
            calibration_config[name] = {}
            if hasattr(module, 'weight_bits'):
                calibration_config[name]['weight_bits'] = int(module.weight_bits)
                calibration_config[name]['weight_scale'] = module.weight_scale
                calibration_config[name]['weight_zero_point'] = module.weight_zero_point
                # Recover weight/bias for batch normalization folding
                actual_weight = getattr(module, 'old_weight', None)
@@ -759,7 +863,7 @@ class DoReFaQuantizer(Quantizer):
class ClipGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        if quant_type == QuantType.OUTPUT:
            grad_output[torch.abs(tensor) > 1] = 0
        return grad_output
......
import logging

try:
    import torch
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
    logging.info("PyTorch is not installed.")
    TORCH_VERSION = None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
from nni.compression.pytorch.quantization.observers import RecordingObserver
from . import default_layers

_logger = logging.getLogger(__name__)
@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
            assert len(inputs) == 1, "Quantization of input only supports ops with single input."
            new_inp = self.quantizer.quant_grad(
                inputs[0],
                QuantType.INPUT,
                self)
            inputs = (new_inp,)
@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
            self.quantizer.quant_grad(
                new_weight,
                QuantType.WEIGHT,
                self, inputs[0])
            result = self.module(*inputs)
@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
        if 'output' in self.config['quant_types']:
            result = self.quantizer.quant_grad(
                result,
                QuantType.OUTPUT,
                self)
        return result
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
        if isinstance(model, torch.nn.DataParallel):
            model = model.module
        model_copied = copy.deepcopy(model)
        self.identity_wrappers = []
        self.conv_bn_patterns = {}
        self.find_conv_bn_patterns(model, dummy_input)
        super().__init__(model, config_list, optimizer)
        self.all_shapes = {}
        self.record_shape(model_copied, dummy_input)
        self.quant_grad = QuantGrad.apply
        if self.optimizer is not None:
            self.patch_optimizer(self.step_with_optimizer)
@@ -845,25 +851,54 @@ class Quantizer(Compressor):
                if successor.op_type == 'BatchNorm2d':
                    self.conv_bn_patterns[node_group.name] = successor.name

    def record_shape(self, model, dummy_input):
        """
        Record input/output's shapes of each module to be quantized

        Parameters
        ----------
        model : torch.nn.Module
            model to be recorded.
        dummy_input : tuple of torch.tensor
            inputs to the model.
        """
        def _pre_forward_hook(self, inp):
            # Only record the first tensor of the input
            return self.pre_forward(inp[0])

        def _post_forward_hook(self, _, out):
            return self.post_forward(out)

        if dummy_input is None:
            return

        all_handles = []
        all_observers = {}
        modules_to_compress = self.get_modules_to_compress()
        compress_names = [layer_info[0].name for layer_info in modules_to_compress]
        for name, module in model.named_modules():
            if name in compress_names:
                all_observers[name] = {}
                all_observers[name]['input_hook'] = RecordingObserver()
                all_observers[name]['output_hook'] = RecordingObserver()
                module.add_module('pre_forward', all_observers[name]['input_hook'])
                module.add_module('post_forward', all_observers[name]['output_hook'])
                all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
                all_handles.append(module.register_forward_hook(_post_forward_hook))
        model(dummy_input)
        for name, hooks in all_observers.items():
            # only support single input
            input_val = hooks['input_hook'].tensor_val
            input_shape = input_val[0].shape if input_val else None
            output_val = hooks['output_hook'].tensor_val
            output_shape = output_val[0].shape if output_val else None
            shapes = [input_shape, output_shape]
            self.all_shapes[name] = shapes
        return

    def step_with_optimizer(self):
        pass
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
class QuantGrad(torch.autograd.Function):
    """
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
        grad_output : Tensor
            gradient of the output of quantization operation
        scale : Tensor
            the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
            `QuantType.OUTPUT`, you can define different behavior for different types.
        zero_point : Tensor
            zero_point for quantizing tensor
        qmin : Tensor
@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
    def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
        output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)

        if hasattr(wrapper.module, "layer_quant_setting"):
            layer_quant_setting = wrapper.module.layer_quant_setting
            qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
        else:
            # todo: when dtype/scheme customization is ready for all quantizers, remove this
            bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
            qmin, qmax = 0, (1 << bits) - 1

        scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
        if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
            scale = getattr(wrapper.module, scale_name)
            zero_point = getattr(wrapper.module, zero_point_name)
        # todo: remove this when other quantizers use different scale & zero point for input/weight/output
        elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
            scale = wrapper.module.scale
            zero_point = wrapper.module.zero_point
        else:
            scale, zero_point = None, None

        # Only tensors that have gradients flowing back need to be saved by save_for_backward.
        # Others should directly assign to ctx.
        ctx.save_for_backward(tensor)
        ctx.quant_type = quant_type
        ctx.qmin, ctx.qmax = qmin, qmax
        ctx.scale = scale
        ctx.zero_point = zero_point
        return output

    @classmethod
    def backward(cls, ctx, grad_output):
        tensor = ctx.saved_variables[0]
        scale, zero_point = ctx.scale, ctx.zero_point
        quant_type = ctx.quant_type
        qmin, qmax = ctx.qmin, ctx.qmax
        output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
        return output, None, None, None
@@ -977,11 +1023,11 @@ def _check_bias(module):
        return False

def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
    if quant_type == QuantType.INPUT:
        output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
    elif quant_type == QuantType.WEIGHT:
        output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
    elif quant_type == QuantType.OUTPUT:
        output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
    else:
        raise ValueError("unrecognized QuantType.")
......
from enum import Enum, EnumMeta


class _QuantLiteralEnumMeta(EnumMeta):
    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
    pass


class QuantScheme(str, _QuantLiteralEnum):
    PER_TENSOR_AFFINE = 'per_tensor_affine'
    PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
    PER_CHANNEL_AFFINE = 'per_channel_affine'
    PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'


PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]


class QuantDtype(str, _QuantLiteralEnum):
    UINT = 'uint'
    INT = 'int'


class QuantType(str, _QuantLiteralEnum):
    INPUT = 'input'
    WEIGHT = 'weight'
    OUTPUT = 'output'

    def type_to_scale_zero_point_name(self):
        if self == QuantType.INPUT:
            return 'input_scale', 'input_zero_point'
        elif self == QuantType.WEIGHT:
            return 'weight_scale', 'weight_zero_point'
        elif self == QuantType.OUTPUT:
            return 'output_scale', 'output_zero_point'
        else:
            raise TypeError


# Just show each attribute's name, no practical effect
class QuantConfigLiteral(str, _QuantLiteralEnum):
    QUANT_SETTINGS = 'quant_settings'
    QUANT_SCHEME = 'quant_scheme'
    QUANT_DTYPE = 'quant_dtype'
    BITS = 'bits'
    QMIN = 'qmin'
    QMAX = 'qmax'
    INPUT_SCALE = 'input_scale'
    INPUT_ZERO_POINT = 'input_zero_point'
    OUTPUT_SCALE = 'output_scale'
    OUTPUT_ZERO_POINT = 'output_zero_point'
    WEIGHT_SCALE = 'weight_scale'
    WEIGHT_ZERO_POINT = 'weight_zero_point'


BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
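A minimal usage sketch of these string-backed enums (illustration only, not part of the file): because they mix in str and the metaclass overrides __contains__, the plain strings that appear in user configs can be compared and membership-tested directly.

from nni.compression.pytorch.quantization.literal import QuantDtype, QuantScheme, QuantType

assert 'weight' in QuantType and 'per_channel_symmetric' in QuantScheme
assert 'float16' not in QuantDtype
assert QuantType.WEIGHT == 'weight'
assert QuantType.WEIGHT.type_to_scale_zero_point_name() == ('weight_scale', 'weight_zero_point')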
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver

__all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]


class RecordingObserver(_RecordingObserver):
    """
    An extended version of PyTorch's RecordingObserver, used to record gpu tensor
    """
    def forward(self, x):
        val = x.cpu()
        super().forward(val)
        return x
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length
# default settings for quantization module
quant_default_settings = {
QuantType.WEIGHT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT,
},
QuantType.INPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
},
QuantType.OUTPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
}
}
class TensorQuantSetting(object):
def __init__(self, **kwargs):
self._fields = {}
for k, v in kwargs.items():
self._fields[k] = v
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self._fields[name] = val
def __getattr__(self, name):
if name == "_fields" or name not in self._fields:
raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
return self._fields[name]
def get_qmin_qmax(self):
assert 'qmin' in self._fields and 'qmax' in self._fields, \
"Can not found qmin & qmax in TensorQuantSetting"
return self._fields['qmin'], self._fields['qmax']
class LayerQuantSetting(object):
def __init__(self, config):
self.input: Optional[TensorQuantSetting] = None
self.weight: Optional[TensorQuantSetting] = None
self.output: Optional[TensorQuantSetting] = None
self._extra_layer_setting = {}
for quant_type in QuantType:
if quant_type in config.get("quant_types", []):
setting = TensorQuantSetting()
quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
setting.quant_scheme = quant_scheme
quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
setting.quant_dtype = quant_dtype
bits = get_bits_length(config, quant_type)
qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
setting.bits = bits
setting.qmin = qmin
setting.qmax = qmax
setattr(self, quant_type, setting)
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_") or name in QuantType:
super().__setattr__(name, val)
else:
self._extra_layer_setting[name] = val
def __getattr__(self, name):
if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
return self._extra_layer_setting[name]
@staticmethod
def parse_optional_config(config, quant_type, target):
def get_config(config, quant_type, target):
if not config.get(target):
return None
if isinstance(config[target], dict):
return config[target].get(quant_type)
else:
return config[target]
default_val = quant_default_settings[quant_type].get(target, None)
config_val = get_config(config, quant_type, target)
val = config_val if config_val else default_val
return val
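For illustration (not part of the file), this is roughly how one entry of the example script's config is parsed into per-tensor settings:

config = {
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'quant_dtype': 'int',
    'quant_scheme': 'per_channel_symmetric',
}
setting = LayerQuantSetting(config)
assert setting.weight.quant_scheme == 'per_channel_symmetric'
assert setting.weight.get_qmin_qmax() == (-127, 127)
assert setting.output is None    # 'output' is not in quant_types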
def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
# todo: remove this if we convert string config to enum type.
if isinstance(quant_type, str):
assert quant_type in QuantType, "Wrong quant_type"
if isinstance(new_scheme, str):
assert new_scheme in QuantScheme, "Wrong quant_scheme"
if isinstance(new_dtype, str):
assert new_dtype in QuantDtype, "Wrong quant_dtype"
# TODO: It is not a good idea to directly modify global settings. A better choice is
# making this function an attribute function of Quantizer and call this function after
# the quantizer is initialized. However, within current framework of quantization, if
# we want to modify the dtype & scheme when the quantizer is initialized, we must do
# some other things (like changing the shapes of scales and zero_points and other quantization
# information in the subclass).
global quant_default_settings
if new_scheme is not None:
quant_default_settings[quant_type]['quant_scheme'] = new_scheme
if new_dtype is not None:
quant_default_settings[quant_type]['quant_dtype'] = new_dtype
return
import torch
from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType
def calculate_qmin_qmax(bits, dtype):
if dtype == QuantDtype.INT:
qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
elif dtype == QuantDtype.UINT:
qmin, qmax = 0, 2 ** bits - 1
else:
raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
return qmin, qmax
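For reference (illustration only, using the names defined in this file), the 8-bit ranges come out as the bounds exercised by the updated tests:

assert calculate_qmin_qmax(8, QuantDtype.INT) == (-127, 127)   # symmetric signed range
assert calculate_qmin_qmax(8, QuantDtype.UINT) == (0, 255)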
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
def get_target_dim(quant_type, quant_scheme):
# for weight: c_out x c_in x (h) * (w)
# for feature maps: batch * channel * (t) * h * w
# other type is not supported for now
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
target_dim = default_idx
else:
target_dim = None
return target_dim
def get_min_max_value(x, quant_type, quant_scheme):
target_dim = get_target_dim(quant_type, quant_scheme)
if target_dim is None:
return torch.min(x), torch.max(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
del indices[target_dim]
if TORCH_VERSION > (1, 6):
min_val = torch.amin(x, indices, keepdims=True)
max_val = torch.amax(x, indices, keepdims=True)
else:
min_val = max_val = x
for ind in indices:
min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
return min_val, max_val
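A small illustration (not part of the file): for a Conv2d weight of shape [2, 1, 3, 3] whose first output channel holds the values 1..9 and whose second holds 10..18,

w = torch.arange(1., 19.).view(2, 1, 3, 3)
mn, mx = get_min_max_value(w, QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC)
# mn, mx have shape [2, 1, 1, 1] with values [1, 10] and [9, 18]
mn_t, mx_t = get_min_max_value(w, QuantType.WEIGHT, QuantScheme.PER_TENSOR_AFFINE)
# mn_t, mx_t are the scalars 1. and 18.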
def get_mean_value(x, target_dim=None):
if target_dim is None:
return torch.mean(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
del indices[target_dim]
mean_val = torch.mean(x, dim=indices, keepdim=True)
return mean_val
def is_per_channel(quant_scheme):
if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
return True
else:
return False
def get_quant_shape(shape, quant_type, quant_scheme):
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
quant_shape = []
return quant_shape
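Illustration (not part of the file): these are the shapes of the scale/zero_point buffers the QAT quantizer registers for per-channel versus per-tensor schemes.

assert get_quant_shape([2, 1, 3, 3], QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC) == [2, 1, 1, 1]
assert get_quant_shape([2, 1, 3, 3], QuantType.WEIGHT, QuantScheme.PER_TENSOR_AFFINE) == []
assert get_quant_shape([1, 1, 4, 4], QuantType.INPUT, QuantScheme.PER_CHANNEL_AFFINE) == [1, 1, 1, 1]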
@@ -9,6 +9,7 @@ import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
import math
@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
        model.relu = torch.nn.ReLU()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        dummy = torch.randn(1, 1, 28, 28)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
        self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
        self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
def test_quantization_dtype_scheme(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
self.bn1 = torch.nn.BatchNorm2d(2)
def forward(self, x):
x = self.bn1(self.conv1(x))
return x
dtypes = ['int', 'uint']
qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
for dtype in dtypes:
for qscheme in qschemes:
config_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': 8,
'op_types': ['Conv2d'],
'quant_dtype': dtype,
'quant_scheme': qscheme
}]
model = TestModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# only QAT_quantizer is supported for now
dummy = torch.randn(1, 1, 4, 4)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
# test layer setting
for layer, config in quantizer.modules_to_compress:
module = layer.module
name = layer.name
layer_setting = module.layer_quant_setting
qmin, qmax = calculate_qmin_qmax(8, dtype)
all_quant_types = ['input', 'weight']
for quant_type in all_quant_types:
# check for settings
tensor_setting = getattr(layer_setting, quant_type)
self.assertTrue(tensor_setting is not None)
self.assertTrue(tensor_setting.quant_scheme == qscheme)
self.assertTrue(tensor_setting.quant_dtype == dtype)
self.assertTrue(tensor_setting.qmin == qmin)
self.assertTrue(tensor_setting.qmax == qmax)
input_shape, output_shape = quantizer.all_shapes[name]
shape = input_shape if quant_type == 'input' else module.weight.shape
quant_shape = get_quant_shape(shape, quant_type, qscheme)
scale_name = quant_type + '_scale'
zero_point_name = quant_type + '_zero_point'
scale = getattr(module, scale_name)
zero_point = getattr(module, zero_point_name)
self.assertTrue(list(scale.shape) == quant_shape)
self.assertTrue(list(zero_point.shape) == quant_shape)
weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(18. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(18. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(18. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(18. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
wrapper = getattr(model, name)
wrapper.module.weight = weight
quantizer.quantize_weight(wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(15. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(15. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0.]).view([1, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(15. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(15. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
quantizer.quantize_input(inp, wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
        model.relu = torch.nn.ReLU()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        dummy = torch.randn(1, 1, 28, 28)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
        quantizer.compress()

        # test quantize
@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.weight_zero_point == 0
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.weight = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.weight_zero_point in (42, 43)
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
        # test value of weight and bias after quantization
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
        # test ema
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        model.relu(x)
        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        model.relu(x)
        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))

    def test_torch_quantizer_export(self):
        config_list_qat = [{
@@ -424,11 +551,14 @@ class CompressorTestCase(TestCase):
        }]
        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
        dummy = torch.randn(1, 1, 28, 28)
        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
                quantizer = quantize_algorithm(model, config, optimizer, dummy)
            else:
                quantizer = quantize_algorithm(model, config, optimizer)
            quantizer.compress()
@@ -461,6 +591,10 @@ class CompressorTestCase(TestCase):
        model = TorchModel().eval()
        model.relu = torch.nn.ReLU()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        if quantize_algorithm == torch_quantizer.QAT_Quantizer:
            dummy = torch.randn(1, 1, 28, 28)
            quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
        else:
            quantizer = quantize_algorithm(model, configure_list, optimizer)
        quantizer.compress()
        if calibration_config is not None:
......