Unverified Commit 7fc5af07 authored by chenbohua3, committed by GitHub

Add batch normalization folding to QAT quantizer (#3911)

parent 441c5da5
@@ -82,10 +82,25 @@ configuration needed by this algorithm :
disable quantization until the model has been run for a certain number of steps, this allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
Batch normalization folding
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Batch normalization folding is supported in the QAT quantizer. It can be easily enabled by passing an argument `dummy_input` to
the quantizer, like:
.. code-block:: python

    import torch
    # import path assumes NNI 2.x
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    # `model` and `config_list` are defined as in the usage example above
    # assume your model takes an input of shape (1, 1, 28, 28)
    # and dummy_input must be on the same device as the model
    dummy_input = torch.randn(1, 1, 28, 28)
    # pass the dummy_input to the quantizer
    quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)
The quantizer will automatically detect Conv-BN patterns and simulate the batch normalization folding process in the training
graph. Note that when the quantization aware training process is finished, the folded weight/bias will be restored after calling
`quantizer.export_model`.
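A minimal sketch of the export step (assuming `model_path` and `calibration_path` are file paths of your choice):

.. code-block:: python

    # after quantization aware training finishes, export_model restores the
    # original (unfolded) weight/bias of the wrapped Conv layers and returns
    # the calibration configuration
    calibration_config = quantizer.export_model(model_path, calibration_path)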
----
...
@@ -6,7 +6,7 @@ import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
"""
Parameters
----------
@@ -145,8 +145,13 @@ class QAT_Quantizer(Quantizer):
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- dummy_input : tuple of tensor
inputs to the model, used to trace the module and build its graph. The graph is used to find
Conv-BN patterns so that batch normalization folding can be enabled. If dummy_input is not
given, batch normalization folding is disabled.
""" """
super().__init__(model, config_list, optimizer)
super().__init__(model, config_list, optimizer, dummy_input)
self.quant_grad = QATGrad.apply self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress() modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device device = next(model.parameters()).device
@@ -169,8 +174,9 @@ class QAT_Quantizer(Quantizer):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
'activation_bit', 'BN_FOLD_TAG']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -334,6 +340,23 @@ class QAT_Quantizer(Quantizer):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
# Recover weight/bias for batch normalization folding
if hasattr(module, BN_FOLD_TAG):
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend.", name)
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)
actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
@@ -344,9 +367,39 @@ class QAT_Quantizer(Quantizer):
return calibration_config
def fold_bn(self, *inputs, wrapper):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for origin module
Returns
-------
Tuple of torch.Tensor
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
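# note: 'running_var' below actually holds the running standard deviation, sqrt(running_var + eps)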
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
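# fold BN into the conv weight: w_fold = w * gamma / sqrt(var + eps), broadcast over the output-channel dimension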
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
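# fold BN into the bias: b_fold = beta + (b - mu) * gamma / sqrt(var + eps); b is treated as 0 when the conv has no bias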
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias
def step_with_optimizer(self):
"""
...
@@ -4,6 +4,7 @@
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from . import default_layers
_logger = logging.getLogger(__name__)
@@ -463,7 +464,7 @@ class Pruner(Compressor):
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
"""
Wrap a module to enable data parallel, forward method customization and buffer registration.
@@ -479,6 +480,8 @@ class QuantizerModuleWrapper(torch.nn.Module):
the type of the module to compress
quantizer : quantizer
the quantizer used to calculate mask
bn_module : torch.nn.Module
batch norm layer corresponding to current module, used for simulating batch normalization folding
""" """
super().__init__() super().__init__()
# origin layer information # origin layer information
...@@ -488,6 +491,7 @@ class QuantizerModuleWrapper(torch.nn.Module): ...@@ -488,6 +491,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner # config and pruner
self.config = config self.config = config
self.quantizer = quantizer self.quantizer = quantizer
self.bn_module = bn_module
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
@@ -501,6 +505,17 @@ class QuantizerModuleWrapper(torch.nn.Module):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
# for batch normalization folding
if self.bn_module is not None:
if _check_bias(self.module):
self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
init_tensor = self.module.old_bias
else:
init_tensor = torch.zeros_like(self.bn_module.weight)
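# 'bias' becomes a buffer here; during forward it is overwritten with the folded bias computed by fold_bn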
delattr(self.module, 'bias')
self.module.register_buffer('bias', init_tensor)
setattr(module, BN_FOLD_TAG, True)
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad(
@@ -509,13 +524,20 @@ class QuantizerModuleWrapper(torch.nn.Module):
self)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
if self.bn_module is not None:
# simulate batch normalization folding
new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
self.module.bias = new_bias
self.module.weight = new_weight
else:
new_weight = self.module.old_weight
self.quantizer.quant_grad(
new_weight,
QuantType.QUANT_WEIGHT,
self, inputs[0])
result = self.module(*inputs)
if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
@@ -525,12 +547,35 @@ class QuantizerModuleWrapper(torch.nn.Module):
return result
class QuantizerIdentityWrapper(torch.nn.Module):
def __init__(self, module, module_name):
"""
Used to wrap modules that should be treated as torch.Identity
Parameters
----------
module : pytorch module
the module to be wrapped
module_name : str
the name of the module to be wrapped; the wrapper module shares the same name
"""
super().__init__()
self.module = module
self.module_name = module_name
def forward(self, x):
return x
class Quantizer(Compressor):
"""
Base quantizer for pytorch quantizer
"""
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
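# identity_wrappers bypass the forward of folded BN layers; conv_bn_patterns maps each conv layer name to the BN layer that follows it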
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
@@ -540,6 +585,10 @@ class Quantizer(Compressor):
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
# This is for conv with bias + bn. Although this situation is relatively rare,
# we still need to deal with the old_bias when it occurs
if hasattr(wrapper.module, "old_bias"):
self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})
def quantize_weight(self, wrapper, **kwargs):
"""
@@ -597,7 +646,36 @@ class Quantizer(Compressor):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
# bind the bn module to its corresponding conv module
bn_module = None
if layer.name in self.conv_bn_patterns:
bn_module_name = self.conv_bn_patterns[layer.name]
for name, module in self.bound_model.named_modules():
if name == bn_module_name:
bn_module = module
break
assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)
def _wrap_model(self):
"""
wrap all modules that need to be compressed
"""
# wrap folded bn in order to bypass its forward process
for wrapper in reversed(self.identity_wrappers):
_setattr(self.bound_model, wrapper.module_name, wrapper)
super()._wrap_model()
def _unwrap_model(self):
"""
unwrap all modules that need to be compressed
"""
for wrapper in self.identity_wrappers:
_setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model()
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
@@ -660,6 +738,30 @@ class Quantizer(Compressor):
"""
raise NotImplementedError('Quantizer must overload export_model()')
def find_conv_bn_patterns(self, model, dummy_input):
"""
Find all Conv-BN patterns, used for batch normalization folding
Parameters
----------
model : torch.nn.Module
model to be analyzed.
dummy_input : tuple of torch.Tensor
inputs to the model, used for generating the torchscript
"""
if dummy_input is None:
_logger.debug("Model inputs are not given, batch normalization folding is disabled")
return
graph = build_module_graph(model, dummy_input)
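# for each Conv2d node, record the BatchNorm2d node that directly follows it in the traced graph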
for node_group in graph.nodes_py.nodes_op:
if node_group.op_type in BN_FOLD_OP:
successors = graph.find_successors(node_group.unique_name)
successors = [graph.name_to_node[x] for x in successors]
for successor in successors:
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name
def step_with_optimizer(self):
pass
@@ -677,6 +779,9 @@ QType_Dict = {
2: "output"
}
BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
@@ -773,6 +878,12 @@ def _check_weight(module):
except AttributeError:
return False
def _check_bias(module):
try:
return isinstance(module.bias.data, torch.Tensor)
except AttributeError:
return False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
...