Unverified Commit c9cd53aa authored by chenbohua3, committed by GitHub

support dtype&scheme customization for QAT quantizer (#4137)

parent b0f34da1
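A minimal usage sketch of the feature this commit adds (the layer names, bit widths, and dummy-input shape are taken from the MNIST example changed below; `model` and `optimizer` are assumed to be defined as in that example, so treat this as an illustration rather than the full API):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
    from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

    # Per-layer customization through the new 'quant_dtype' / 'quant_scheme' config keys.
    config_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1', 'conv2'],
        'quant_dtype': 'int',
        'quant_scheme': 'per_channel_symmetric'
    }]

    # Or change the defaults for a whole quant type before constructing the quantizer.
    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')

    # QAT_Quantizer now takes a dummy_input so it can record layer shapes for per-channel scale/zero_point.
    dummy_input = torch.randn(1, 1, 28, 28)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()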
...@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw ...@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
grad_output : Tensor grad_output : Tensor
gradient of the output of quantization operation gradient of the output of quantization operation
quant_type : QuantType quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
you can define different behavior for different types. you can define different behavior for different types.
Returns Returns
------- -------
...@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw ...@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
""" """
# for quant_output function, set grad to zero if the absolute value of tensor is larger than 1 # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
if quant_type == QuantType.QUANT_OUTPUT: if quant_type == QuantType.OUTPUT:
grad_output[torch.abs(tensor) > 1] = 0 grad_output[torch.abs(tensor) > 1] = 0
return grad_output return grad_output
......
...@@ -2,11 +2,13 @@ import torch ...@@ -2,11 +2,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
import sys import sys
sys.path.append('../models') sys.path.append('../models')
from mnist.naive import NaiveModel from mnist.naive import NaiveModel
def train(model, device, train_loader, optimizer): def train(model, device, train_loader, optimizer):
model.train() model.train()
for batch_idx, (data, target) in enumerate(train_loader): for batch_idx, (data, target) in enumerate(train_loader):
...@@ -58,22 +60,32 @@ def main(): ...@@ -58,22 +60,32 @@ def main():
# {'quant_types': ['input'], 'op_names': ['b']} in the configure_list. # {'quant_types': ['input'], 'op_names': ['b']} in the configure_list.
configure_list = [{ configure_list = [{
'quant_types': ['weight', 'input'], 'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8}, 'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1', 'conv2'] 'op_names': ['conv1', 'conv2']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
'quant_bits': {'output': 8, }, 'quant_bits': {'output': 8, },
'op_names': ['relu1', 'relu2'] 'op_names': ['relu1', 'relu2']
}, { }, {
'quant_types': ['output', 'weight', 'input'], 'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8}, 'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1'], 'op_names': ['fc1', 'fc2'],
}, { }]
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8}, # you can also set the quantization dtype and scheme layer-wise through configure_list like:
'op_names': ['fc2'], # configure_list = [{
}] # 'quant_types': ['weight', 'input'],
# 'quant_bits': {'weight': 8, 'input': 8},
# 'op_names': ['conv1', 'conv2'],
# 'quant_dtype': 'int',
# 'quant_scheme': 'per_channel_symmetric'
# }]
# For now, quant_dtype's options are 'int' and 'uint', and quant_scheme's options are 'per_tensor_affine',
# 'per_tensor_symmetric', 'per_channel_affine' and 'per_channel_symmetric'.
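# Note: set_quant_scheme_dtype below only changes the *default* scheme/dtype for a quant type;
# a 'quant_scheme'/'quant_dtype' key inside a configure_list entry takes precedence for that layer.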
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')
model = NaiveModel().to(device) model = NaiveModel().to(device)
dummy_input = torch.randn(1, 1, 28, 28).to(device) dummy_input = torch.randn(1, 1, 28, 28).to(device)
...@@ -98,5 +110,6 @@ def main(): ...@@ -98,5 +110,6 @@ def main():
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device) calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
print("Generated calibration config is: ", calibration_config) print("Generated calibration config is: ", calibration_config)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
import logging
try:
import torch
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
logging.info("PyTorch is not installed.")
TORCH_VERSION = None
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import copy
import types import types
import logging import logging
import torch import torch
from nni.common.graph_utils import build_module_graph from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
from nni.compression.pytorch.quantization.observers import RecordingObserver
from . import default_layers from . import default_layers
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module): ...@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
assert len(inputs) == 1, "Quantization of input only supports ops with a single input." assert len(inputs) == 1, "Quantization of input only supports ops with a single input."
new_inp = self.quantizer.quant_grad( new_inp = self.quantizer.quant_grad(
inputs[0], inputs[0],
QuantType.QUANT_INPUT, QuantType.INPUT,
self) self)
inputs = (new_inp,) inputs = (new_inp,)
...@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module): ...@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self.quantizer.quant_grad( self.quantizer.quant_grad(
new_weight, new_weight,
QuantType.QUANT_WEIGHT, QuantType.WEIGHT,
self, inputs[0]) self, inputs[0])
result = self.module(*inputs) result = self.module(*inputs)
...@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module): ...@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
if 'output' in self.config['quant_types']: if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad( result = self.quantizer.quant_grad(
result, result,
QuantType.QUANT_OUTPUT, QuantType.OUTPUT,
self) self)
return result return result
...@@ -604,10 +607,13 @@ class Quantizer(Compressor): ...@@ -604,10 +607,13 @@ class Quantizer(Compressor):
def __init__(self, model, config_list, optimizer=None, dummy_input=None): def __init__(self, model, config_list, optimizer=None, dummy_input=None):
if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
model = model.module model = model.module
model_copied = copy.deepcopy(model)
self.identity_wrappers = [] self.identity_wrappers = []
self.conv_bn_patterns = {} self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input) self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer) super().__init__(model, config_list, optimizer)
self.all_shapes = {}
self.record_shape(model_copied, dummy_input)
self.quant_grad = QuantGrad.apply self.quant_grad = QuantGrad.apply
if self.optimizer is not None: if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer) self.patch_optimizer(self.step_with_optimizer)
...@@ -845,25 +851,54 @@ class Quantizer(Compressor): ...@@ -845,25 +851,54 @@ class Quantizer(Compressor):
if successor.op_type == 'BatchNorm2d': if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name self.conv_bn_patterns[node_group.name] = successor.name
def step_with_optimizer(self): def record_shape(self, model, dummy_input):
pass """
Record the input/output shapes of each module to be quantized
class QuantType: Parameters
""" ----------
Enum class for quantization type. model : torch.nn.Module
""" model to be recorded.
QUANT_INPUT = 0 dummy_input : tuple of torch.Tensor
QUANT_WEIGHT = 1 inputs to the model.
QUANT_OUTPUT = 2 """
def _pre_forward_hook(self, inp):
# Only record the first tensor of the input
return self.pre_forward(inp[0])
def _post_forward_hook(self, _, out):
return self.post_forward(out)
if dummy_input is None:
return
all_handles = []
all_observers = {}
modules_to_compress = self.get_modules_to_compress()
compress_names = [layer_info[0].name for layer_info in modules_to_compress]
for name, module in model.named_modules():
if name in compress_names:
all_observers[name] = {}
all_observers[name]['input_hook'] = RecordingObserver()
all_observers[name]['output_hook'] = RecordingObserver()
module.add_module('pre_forward', all_observers[name]['input_hook'])
module.add_module('post_forward', all_observers[name]['output_hook'])
all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
all_handles.append(module.register_forward_hook(_post_forward_hook))
model(dummy_input)
for name, hooks in all_observers.items():
# only a single input is supported
input_val = hooks['input_hook'].tensor_val
input_shape = input_val[0].shape if input_val else None
output_val = hooks['output_hook'].tensor_val
output_shape = output_val[0].shape if output_val else None
shapes = [input_shape, output_shape]
self.all_shapes[name] = shapes
return
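# After this pass, self.all_shapes maps each compressed layer's name to [input_shape, output_shape];
# the QAT quantizer uses these shapes (via get_quant_shape) to size the per-channel scale and
# zero_point buffers, as exercised by the unit test added in this commit.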
QType_Dict = { def step_with_optimizer(self):
0: "input", pass
1: "weight",
2: "output"
}
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
class QuantGrad(torch.autograd.Function): class QuantGrad(torch.autograd.Function):
""" """
...@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function): ...@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
grad_output : Tensor grad_output : Tensor
gradient of the output of quantization operation gradient of the output of quantization operation
scale : Tensor scale : Tensor
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
`QuantType.QUANT_OUTPUT`, you can define different behavior for different types. `QuantType.OUTPUT`, you can define different behavior for different types.
zero_point : Tensor zero_point : Tensor
zero_point for quantizing tensor zero_point for quantizing tensor
qmin : Tensor qmin : Tensor
...@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function): ...@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs): def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs) output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)
bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type]) if hasattr(wrapper.module, "layer_quant_setting"):
qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device) layer_quant_setting = wrapper.module.layer_quant_setting
if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'): qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
else:
# todo: when dtype/scheme customization is ready for all quantizers, remove this
bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = 0, (1 << bits) - 1
scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
scale = getattr(wrapper.module, scale_name)
zero_point = getattr(wrapper.module, zero_point_name)
# todo: remove this when other quantizers use different scale & zero point for input/weight/output
elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
scale = wrapper.module.scale scale = wrapper.module.scale
zero_point = wrapper.module.zero_point zero_point = wrapper.module.zero_point
else: else:
scale, zero_point = None, None scale, zero_point = None, None
ctx.save_for_backward(tensor)
# Only tensors that have gradients flowing back need to be saved by save_for_backward. # Only tensors that have gradients flowing back need to be saved by save_for_backward.
# Others should directly assign to ctx. # Others should directly assign to ctx.
ctx.scale = scale ctx.save_for_backward(tensor)
ctx.zero_point = zero_point
ctx.quant_type = quant_type ctx.quant_type = quant_type
ctx.qmin, ctx.qmax = qmin, qmax ctx.qmin, ctx.qmax = qmin, qmax
ctx.scale = scale
ctx.zero_point = zero_point
return output return output
@classmethod @classmethod
def backward(cls, ctx, grad_output): def backward(cls, ctx, grad_output):
tensor = ctx.saved_variables[0] tensor = ctx.saved_variables[0]
scale, zero_point = ctx.scale, ctx.zero_point scale, zero_point = ctx.scale, ctx.zero_point
qmin, qmax = ctx.qmin, ctx.qmax
quant_type = ctx.quant_type quant_type = ctx.quant_type
qmin, qmax = ctx.qmin, ctx.qmax
output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax) output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
return output, None, None, None return output, None, None, None
...@@ -977,11 +1023,11 @@ def _check_bias(module): ...@@ -977,11 +1023,11 @@ def _check_bias(module):
return False return False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT: if quant_type == QuantType.INPUT:
output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs) output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT: elif quant_type == QuantType.WEIGHT:
output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT: elif quant_type == QuantType.OUTPUT:
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else: else:
raise ValueError("unrecognized QuantType.") raise ValueError("unrecognized QuantType.")
......
from enum import Enum, EnumMeta
class _QuantLiteralEnumMeta(EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
pass
class QuantScheme(str, _QuantLiteralEnum):
PER_TENSOR_AFFINE = 'per_tensor_affine'
PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
PER_CHANNEL_AFFINE = 'per_channel_affine'
PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'
PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]
class QuantDtype(str, _QuantLiteralEnum):
UINT = 'uint'
INT = 'int'
class QuantType(str, _QuantLiteralEnum):
INPUT = 'input'
WEIGHT = 'weight'
OUTPUT = 'output'
def type_to_scale_zero_point_name(self):
if self == QuantType.INPUT:
return 'input_scale', 'input_zero_point'
elif self == QuantType.WEIGHT:
return 'weight_scale', 'weight_zero_point'
elif self == QuantType.OUTPUT:
return 'output_scale', 'output_zero_point'
else:
raise TypeError
# This class only enumerates attribute names; it has no practical effect.
class QuantConfigLiteral(str, _QuantLiteralEnum):
QUANT_SETTINGS = 'quant_settings'
QUANT_SCHEME = 'quant_scheme'
QUANT_DTYPE = 'quant_dtype'
BITS = 'bits'
QMIN = 'qmin'
QMAX = 'qmax'
INPUT_SCALE = 'input_scale'
INPUT_ZERO_POINT = 'input_zero_point'
OUTPUT_SCALE = 'output_scale'
OUTPUT_ZERO_POINT = 'output_zero_point'
WEIGHT_SCALE = 'weight_scale'
WEIGHT_ZERO_POINT = 'weight_zero_point'
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
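# Illustrative sketch (not part of this commit): _QuantLiteralEnumMeta overrides __contains__, so plain
# strings can be checked for membership directly; set_quant_scheme_dtype in settings.py relies on this.
assert 'weight' in QuantType                    # QuantType('weight') succeeds, so membership is True
assert 'per_channel_symmetric' in QuantScheme
assert 'float16' not in QuantDtype              # QuantDtype('float16') raises ValueError -> not a member
assert QuantType.WEIGHT == 'weight'             # str-based enum members compare equal to their literals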
from torch.quantization import default_weight_observer, default_histogram_observer from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
__all__ = ["default_weight_observer", "default_histogram_observer"] __all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]
class RecordingObserver(_RecordingObserver):
"""
An extended version of PyTorch's RecordingObserver, used to record GPU tensors
"""
def forward(self, x):
val = x.cpu()
super().forward(val)
return x
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length
# default settings for quantization module
quant_default_settings = {
QuantType.WEIGHT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT,
},
QuantType.INPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
},
QuantType.OUTPUT: {
'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
'quant_dtype': QuantDtype.UINT
}
}
class TensorQuantSetting(object):
def __init__(self, **kwargs):
self._fields = {}
for k, v in kwargs.items():
self._fields[k] = v
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self._fields[name] = val
def __getattr__(self, name):
if name == "_fields" or name not in self._fields:
raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
return self._fields[name]
def get_qmin_qmax(self):
assert 'qmin' in self._fields and 'qmax' in self._fields, \
"Can not found qmin & qmax in TensorQuantSetting"
return self._fields['qmin'], self._fields['qmax']
class LayerQuantSetting(object):
def __init__(self, config):
self.input: Optional[TensorQuantSetting] = None
self.weight: Optional[TensorQuantSetting] = None
self.output: Optional[TensorQuantSetting] = None
self._extra_layer_setting = {}
for quant_type in QuantType:
if quant_type in config.get("quant_types", []):
setting = TensorQuantSetting()
quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
setting.quant_scheme = quant_scheme
quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
setting.quant_dtype = quant_dtype
bits = get_bits_length(config, quant_type)
qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
setting.bits = bits
setting.qmin = qmin
setting.qmax = qmax
setattr(self, quant_type, setting)
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_") or name in QuantType:
super().__setattr__(name, val)
else:
self._extra_layer_setting[name] = val
def __getattr__(self, name):
if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
return self._extra_layer_setting[name]
@staticmethod
def parse_optional_config(config, quant_type, target):
def get_config(config, quant_type, target):
if not config.get(target):
return None
if isinstance(config[target], dict):
return config[target].get(quant_type)
else:
return config[target]
default_val = quant_default_settings[quant_type].get(target, None)
config_val = get_config(config, quant_type, target)
val = config_val if config_val else default_val
return val
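# Illustrative sketch (not part of this commit): building a LayerQuantSetting from a single config entry.
example_config = {
    'quant_types': ['weight'],
    'quant_bits': 8,
    'quant_dtype': 'int',
    'quant_scheme': 'per_channel_symmetric',
}
example_setting = LayerQuantSetting(example_config)
assert example_setting.weight.quant_scheme == 'per_channel_symmetric'
assert example_setting.weight.get_qmin_qmax() == (-127, 127)   # 8-bit symmetric int range
assert example_setting.input is None                           # 'input' is not in quant_types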
def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
# todo: remove this if we convert string config to enum type.
if isinstance(quant_type, str):
assert quant_type in QuantType, "Wrong quant_type"
if isinstance(new_scheme, str):
assert new_scheme in QuantScheme, "Wrong quant_scheme"
if isinstance(new_dtype, str):
assert new_dtype in QuantDtype, "Wrong quant_dtype"
# TODO: It is not a good idea to directly modify global settings. A better choice would be to make
# this function a method of Quantizer and call it after the quantizer is initialized. However, within
# the current quantization framework, modifying the dtype & scheme after the quantizer is initialized
# would also require updating other quantization state in the subclass (such as the shapes of the
# scales and zero_points).
global quant_default_settings
if new_scheme is not None:
quant_default_settings[quant_type]['quant_scheme'] = new_scheme
if new_dtype is not None:
quant_default_settings[quant_type]['quant_dtype'] = new_dtype
return
import torch
from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType
def calculate_qmin_qmax(bits, dtype):
if dtype == QuantDtype.INT:
qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
elif dtype == QuantDtype.UINT:
qmin, qmax = 0, 2 ** bits - 1
else:
raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
return qmin, qmax
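# For example, with bits=8: calculate_qmin_qmax(8, QuantDtype.INT) returns (-127, 127) (symmetric, -128 excluded)
# and calculate_qmin_qmax(8, QuantDtype.UINT) returns (0, 255).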
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
def get_target_dim(quant_type, quant_scheme):
# for weight: c_out x c_in x (h) x (w)
# for feature maps: batch x channel x (t) x h x w
# other types are not supported for now
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
target_dim = default_idx
else:
target_dim = None
return target_dim
def get_min_max_value(x, quant_type, quant_scheme):
target_dim = get_target_dim(quant_type, quant_scheme)
if target_dim is None:
return torch.min(x), torch.max(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dimensions of the tensor"
del indices[target_dim]
if TORCH_VERSION > (1, 6):
min_val = torch.amin(x, indices, keepdims=True)
max_val = torch.amax(x, indices, keepdims=True)
else:
min_val = max_val = x
for ind in indices:
min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
return min_val, max_val
def get_mean_value(x, target_dim=None):
if target_dim is None:
return torch.mean(x)
indices = list(range(len(x.shape)))
assert target_dim < len(indices), "target_dim needs to be less than the number of dimensions of the tensor"
del indices[target_dim]
mean_val = torch.mean(x, dim=indices, keepdim=True)
return mean_val
def is_per_channel(quant_scheme):
if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
return True
else:
return False
def get_quant_shape(shape, quant_type, quant_scheme):
default_idx = 0 if quant_type == QuantType.WEIGHT else 1
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
quant_shape = []
return quant_shape
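# For example (values mirror the new unit test): a Conv2d weight of shape [2, 1, 3, 3] quantized per-channel
# keeps only the output-channel dim, a per-channel activation of shape [1, 1, 4, 4] keeps only dim 1, and any
# per-tensor scheme yields an empty shape (a scalar scale/zero_point):
#   get_quant_shape([2, 1, 3, 3], QuantType.WEIGHT, QuantScheme.PER_CHANNEL_SYMMETRIC) -> [2, 1, 1, 1]
#   get_quant_shape([1, 1, 4, 4], QuantType.INPUT, QuantScheme.PER_CHANNEL_SYMMETRIC)  -> [1, 1, 1, 1]
#   get_quant_shape([2, 1, 3, 3], QuantType.WEIGHT, QuantScheme.PER_TENSOR_AFFINE)     -> []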
...@@ -9,6 +9,7 @@ import torch.nn.functional as F ...@@ -9,6 +9,7 @@ import torch.nn.functional as F
import schema import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
import math import math
...@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase): ...@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer) dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress() quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress() modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [t[0].name for t in modules_to_compress] modules_to_compress_name = [t[0].name for t in modules_to_compress]
...@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase): ...@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter)) self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter)) self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
def test_quantization_dtype_scheme(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
self.bn1 = torch.nn.BatchNorm2d(2)
def forward(self, x):
x = self.bn1(self.conv1(x))
return x
dtypes = ['int', 'uint']
qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
for dtype in dtypes:
for qscheme in qschemes:
config_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': 8,
'op_types': ['Conv2d'],
'quant_dtype': dtype,
'quant_scheme': qscheme
}]
model = TestModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# only QAT_Quantizer is supported for now
dummy = torch.randn(1, 1, 4, 4)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
# test layer setting
for layer, config in quantizer.modules_to_compress:
module = layer.module
name = layer.name
layer_setting = module.layer_quant_setting
qmin, qmax = calculate_qmin_qmax(8, dtype)
all_quant_types = ['input', 'weight']
for quant_type in all_quant_types:
# check for settings
tensor_setting = getattr(layer_setting, quant_type)
self.assertTrue(tensor_setting is not None)
self.assertTrue(tensor_setting.quant_scheme == qscheme)
self.assertTrue(tensor_setting.quant_dtype == dtype)
self.assertTrue(tensor_setting.qmin == qmin)
self.assertTrue(tensor_setting.qmax == qmax)
input_shape, output_shape = quantizer.all_shapes[name]
shape = input_shape if quant_type == 'input' else module.weight.shape
quant_shape = get_quant_shape(shape, quant_type, qscheme)
scale_name = quant_type + '_scale'
zero_point_name = quant_type + '_zero_point'
scale = getattr(module, scale_name)
zero_point = getattr(module, zero_point_name)
self.assertTrue(list(scale.shape) == quant_shape)
self.assertTrue(list(zero_point.shape) == quant_shape)
weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
target_zero_point = torch.ones([2, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(18. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(18. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(18. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(18. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
wrapper = getattr(model, name)
wrapper.module.weight = weight
quantizer.quantize_weight(wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
if qscheme == 'per_channel_symmetric':
if dtype == 'int':
target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 0
else:
target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
target_zero_point = torch.ones([1, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(15. / 127)
target_zero_point = torch.zeros([])
else:
target_scale = torch.tensor(15. / 127.5)
target_zero_point = torch.ones([]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0.]).view([1, 1, 1, 1])
if dtype == 'int':
target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
target_zero_point = -127 - torch.round(min_val / target_scale)
else:
target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(15. / 254)
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(15. / 255)
target_zero_point = 0 - torch.round(0 / target_scale)
quantizer.quantize_input(inp, wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
def test_torch_QAT_quantizer(self): def test_torch_QAT_quantizer(self):
model = TorchModel() model = TorchModel()
config_list = [{ config_list = [{
...@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase): ...@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer) dummy = torch.randn(1, 1, 28, 28)
quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
quantizer.compress() quantizer.compress()
# test quantize # test quantize
...@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase): ...@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
weight = torch.tensor([[1, 2], [3, 5]]).float() weight = torch.tensor([[1, 2], [3, 5]]).float()
model.conv2.module.weight.data = weight model.conv2.module.weight.data = weight
quantizer.quantize_weight(model.conv2, input_tensor=input) quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps) assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0 assert model.conv2.module.weight_zero_point == 0
quantizer.quantize_input(input, model.conv2) quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255]))) self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.]))) self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# range including 0 # range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float() weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.weight = weight model.conv2.module.weight = weight
quantizer.quantize_weight(model.conv2, input_tensor=input) quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps) assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43) assert model.conv2.module.weight_zero_point in (42, 43)
quantizer.quantize_input(input, model.conv2) quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255]))) self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.]))) self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
# test value of weight and bias after quantization # test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]]) weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]]) weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
...@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase): ...@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
# test ema # test ema
eps = 1e-7 eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x) model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps) self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps) self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
quantizer.step_with_optimizer() quantizer.step_with_optimizer()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x) model.relu(x)
assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps) self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps) self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
def test_torch_quantizer_export(self): def test_torch_quantizer_export(self):
config_list_qat = [{ config_list_qat = [{
...@@ -424,12 +551,15 @@ class CompressorTestCase(TestCase): ...@@ -424,12 +551,15 @@ class CompressorTestCase(TestCase):
}] }]
config_set = [config_list_qat, config_list_dorefa, config_list_bnn] config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer] quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
dummy = torch.randn(1, 1, 28, 28)
for config, quantize_algorithm in zip(config_set, quantize_algorithm_set): for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
model = TorchModel() model = TorchModel()
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = quantize_algorithm(model, config, optimizer) if quantize_algorithm == torch_quantizer.QAT_Quantizer:
quantizer = quantize_algorithm(model, config, optimizer, dummy)
else:
quantizer = quantize_algorithm(model, config, optimizer)
quantizer.compress() quantizer.compress()
x = torch.rand((1, 1, 28, 28), requires_grad=True) x = torch.rand((1, 1, 28, 28), requires_grad=True)
...@@ -461,7 +591,11 @@ class CompressorTestCase(TestCase): ...@@ -461,7 +591,11 @@ class CompressorTestCase(TestCase):
model = TorchModel().eval() model = TorchModel().eval()
model.relu = torch.nn.ReLU() model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = quantize_algorithm(model, configure_list, optimizer) if quantize_algorithm == torch_quantizer.QAT_Quantizer:
dummy = torch.randn(1, 1, 28, 28)
quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
else:
quantizer = quantize_algorithm(model, configure_list, optimizer)
quantizer.compress() quantizer.compress()
if calibration_config is not None: if calibration_config is not None:
quantizer.load_calibration_config(calibration_config) quantizer.load_calibration_config(calibration_config)
......