Unverified Commit 7b16db17 authored by lin bin, committed by GitHub

[Quantization] support bn-folding for lsq (#4148)

parent 19914055
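For context: batch normalization folding merges a Conv → BN pair into a single convolution whose weight and bias absorb the BN affine transform. With running mean μ, running variance σ², affine parameters γ and β, and stabilizer ε, the folded parameters (exactly the quantities computed by the fold_bn helper introduced in this diff) are:

```latex
W_{\text{fold}} = W \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \varepsilon}}, \qquad
b_{\text{fold}} = \beta + (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \varepsilon}}
```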
@@ -635,40 +635,6 @@ class QAT_Quantizer(Quantizer):
return calibration_config
def fold_bn(self, *inputs, wrapper):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for the original module
Returns
-------
Tuple of torch.Tensor
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias
def step_with_optimizer(self):
"""
override `compressor` `step` method, quantization only happens after certain number of steps
@@ -890,7 +856,7 @@ class LsqQuantizer(Quantizer):
https://arxiv.org/pdf/1902.08153.pdf
"""
-def __init__(self, model, config_list, optimizer):
+def __init__(self, model, config_list, optimizer, dummy_input=None):
"""
Parameters
----------
@@ -909,9 +875,13 @@ class LsqQuantizer(Quantizer):
state where output quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- dummy_input : tuple of tensor
inputs to the model, used to trace the graph of the module. The traced graph is used to find
Conv-Bn patterns, and batch normalization folding is then enabled for the matched layers.
If dummy_input is not given, batch normalization folding is disabled.
""" """
assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type" assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
super().__init__(model, config_list, optimizer) super().__init__(model, config_list, optimizer, dummy_input)
device = next(model.parameters()).device device = next(model.parameters()).device
self.quant_grad = QuantForward() self.quant_grad = QuantForward()
modules_to_compress = self.get_modules_to_compress() modules_to_compress = self.get_modules_to_compress()
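A usage sketch of the new dummy_input argument (illustrative only; the NNI 2.x import path, the config-list schema, the TinyNet model and the hyper-parameters are assumptions, not part of this commit):

```python
import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)   # Conv immediately followed by BN:
        self.bn = nn.BatchNorm2d(8)      # the pattern BN folding looks for
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

model = TinyNet()
config_list = [{'quant_types': ['weight', 'input'],
                'quant_bits': {'weight': 8, 'input': 8},
                'op_types': ['Conv2d']}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Passing dummy_input lets the quantizer trace the graph, detect the Conv-Bn
# pair, and enable batch normalization folding; omit it to keep folding disabled.
dummy_input = torch.randn(1, 3, 32, 32)
quantizer = LsqQuantizer(model, config_list, optimizer, dummy_input=dummy_input)
quantizer.compress()
```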
@@ -1057,6 +1027,19 @@ class LsqQuantizer(Quantizer):
abs_max_input = float(module.input_scale * module.input_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_input
calibration_config[name]['tracked_max_input'] = abs_max_input
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend.", name)
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)
if hasattr(module, BN_FOLD_TAG):
actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if hasattr(module, 'output_bits'):
calibration_config[name]['output_bits'] = int(module.output_bits)
abs_max_output = float(module.output_scale * module.output_qmax)
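The block added above restores the full-precision weight (and, for BN-folded layers, the bias) as ordinary parameters before export. A hedged usage sketch, continuing the hypothetical quantizer from the earlier example and assuming the standard Quantizer.export_model(model_path, calibration_path, ...) signature; file paths are illustrative:

```python
# After simulated quantization-aware training, export the model weights and the
# calibration parameters collected above; 'weight'/'bias' are plain parameters again.
calibration_config = quantizer.export_model('lsq_model.pth', 'lsq_calibration.pth')
print(calibration_config)
```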
@@ -1074,7 +1057,7 @@ class LsqQuantizer(Quantizer):
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_output', \
-'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale','weight_bits', 'output_bits', 'input_bits']
+'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale','weight_bits', 'output_bits', 'input_bits', 'BN_FOLD_TAG']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
...
@@ -658,6 +658,40 @@ class Quantizer(Compressor):
"""
raise NotImplementedError('Quantizer must overload quantize_input()')
def fold_bn(self, *inputs, wrapper):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for the original module
Returns
-------
Tuple of torch.Tensor
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias
def _wrap_modules(self, layer, config):
"""
Create a wrapper forward function to replace the original one.
...
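The folding arithmetic in fold_bn above can be checked in isolation: with BN in eval mode, a Conv2d that uses the folded weight and bias reproduces the Conv → BN output. A minimal standalone sketch, independent of the NNI wrappers (layer shapes and the number of warm-up batches are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8)

# Accumulate BN running statistics with a few forward passes in training mode.
with torch.no_grad():
    for _ in range(10):
        bn(conv(torch.randn(4, 3, 16, 16)))
conv.eval()
bn.eval()

# Same formulas as fold_bn: scale by gamma / sqrt(running_var + eps),
# then shift by beta plus the centred original bias.
std = torch.sqrt(bn.running_var + bn.eps)
folded_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
folded_bias = bn.bias + (conv.bias - bn.running_mean) * bn.weight / std

folded_conv = nn.Conv2d(3, 8, 3, bias=True)
with torch.no_grad():
    folded_conv.weight.copy_(folded_weight)
    folded_conv.bias.copy_(folded_bias)
folded_conv.eval()

x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), folded_conv(x), atol=1e-5)
```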