"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "ed46b8c1da3656a32aa0e98ea5dd0d71cd0e6716"
Unverified commit 0a6c234a, authored by lin bin, committed by GitHub

Add bias quantization in QAT and refactor the code of weight quantization (#2914)

parent 6126960c
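The core refactor: quantize_weight no longer receives the weight tensor as an argument. Each quantizer now reads the fp32 master copy from wrapper.module.old_weight and writes the fake-quantized result back to wrapper.module.weight itself, instead of the wrapper assigning the returned tensor. A minimal sketch of the old versus new calling convention (the ToyQuantizer below is hypothetical; only the attribute names follow the diff):

    import copy
    import torch

    class ToyQuantizer:
        # Old contract: the framework passed the tensor in and assigned the result:
        #   wrapper.module.weight = quantize_weight(weight, wrapper)
        def quantize_weight_old(self, weight, wrapper, **kwargs):
            return weight.round()

        # New contract: read the fp32 master copy, write the result back yourself.
        def quantize_weight(self, wrapper, **kwargs):
            weight = copy.deepcopy(wrapper.module.old_weight.data)
            weight = weight.round()          # placeholder for a real scheme
            wrapper.module.weight = weight   # the quantizer updates the module itself
            return weight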
@@ -481,11 +481,10 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self)
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            new_weight = self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
                 self)
-            self.module.weight = new_weight
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
@@ -617,7 +616,7 @@ class QuantGrad(torch.autograd.Function):
         if quant_type == QuantType.QUANT_INPUT:
             return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
             return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
...
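QuantGrad.apply above only dispatches on quant_type; the backward pass is what makes QAT trainable. A standalone sketch of the usual pattern, assuming a straight-through estimator (STE) backward — the STE is an assumption here, this hunk shows only the forward dispatch:

    import torch

    class QuantGradSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            return tensor.round()  # placeholder fake-quantization

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output     # STE: gradient passes through unchanged

    x = torch.randn(4, requires_grad=True)
    QuantGradSketch.apply(x).sum().backward()
    assert torch.all(x.grad == 1)  # rounding did not block the gradient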
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
+import copy
 import torch
 from schema import Schema, And, Or, Optional
 from ..utils.config_validation import CompressorSchema
@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.layer_scale = {}
@@ -29,13 +31,15 @@ class NaiveQuantizer(Quantizer):
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
         orig_type = weight.type()  # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight = weight
+        return weight
 
 
 def update_ema(biased_ema, value, decay, step):
     """
@@ -60,6 +64,7 @@ def update_ema(biased_ema, value, decay, step):
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
     return biased_ema, unbiased_ema
 
+
 def update_quantization_param(bits, rmin, rmax):
     """
     calculate the `zero_point` and `scale`.
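update_quantization_param computes the affine parameters of the asymmetric scheme from the Jacob et al. paper cited below. A sketch of the usual formulation (the real NNI helper may differ in clamping details): the real range is first stretched to include 0 so that 0.0 is exactly representable, then scale = (rmax - rmin) / (qmax - qmin) and zero_point = round(qmin - rmin / scale):

    def quantization_param_sketch(bits, rmin, rmax):
        rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)  # make 0.0 representable
        qmin, qmax = 0, (1 << bits) - 1
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = round(qmin - rmin / scale)
        return scale, max(qmin, min(zero_point, qmax))

    # matches the tests below: range [1, 5] -> scale 5/255, zero_point 0;
    # range [-1, 5] -> scale 6/255, zero_point round(42.5)
    print(quantization_param_sketch(8, 1.0, 5.0))
    print(quantization_param_sketch(8, -1.0, 5.0))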
@@ -116,6 +121,7 @@ class QAT_Quantizer(Quantizer):
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
+
     def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
@@ -215,20 +221,35 @@ class QAT_Quantizer(Quantizer):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.steps:
             return weight
 
+        # if bias exists, quantize bias to uint32
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+            bias_bits = 32
+            rmin, rmax = torch.min(bias), torch.max(bias)
+            module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+            bias = self._quantize(bias_bits, module, bias)
+            bias = self._dequantize(module, bias)
+            wrapper.module.bias.data = bias
+
+        # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight = weight
+        return weight
 
     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config
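The new branch fake-quantizes an existing bias with the same affine scheme, at a fixed 32 bits, before the weight is handled (note that module.scale and module.zero_point are computed for the bias first and then overwritten by the weight pass). On a 32-bit grid the round trip is essentially lossless, which is why the test below expects bias_valid to equal bias. A standalone sketch (affine_params is the hypothetical helper from the sketch above; float64 keeps the large integer grid exact):

    import torch

    def affine_params(bits, rmin, rmax):
        rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
        scale = (rmax - rmin) / ((1 << bits) - 1)
        return scale, round(-rmin / scale)

    bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341], dtype=torch.float64)
    scale, zp = affine_params(32, float(bias.min()), float(bias.max()))
    q = torch.round(bias / scale) + zp          # quantize to the uint32 grid
    dq = (q - zp) * scale                       # dequantize
    assert torch.allclose(dq, bias, rtol=1e-7)  # error on the order of 1e-9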
@@ -241,8 +262,10 @@ class QAT_Quantizer(Quantizer):
             return output
 
         current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
         module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
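quantize_output tracks the activation range with an exponential moving average plus the bias correction implemented in update_ema above, then derives scale and zero_point from the smoothed min/max. A standalone sketch of the EMA update (the decayed-sum form is an assumption; only the correction line appears in this diff):

    def update_ema_sketch(biased_ema, value, decay, step):
        biased_ema = biased_ema * decay + (1 - decay) * value
        unbiased_ema = biased_ema / (1 - decay ** step)  # bias correction
        return biased_ema, unbiased_ema

    biased = 0.0
    for step, batch_max in enumerate([1.0, 1.2, 0.9], start=1):
        biased, unbiased = update_ema_sketch(biased, batch_max, 0.99, step)
    # the correction keeps early estimates near the observed ~1.0
    # instead of dragging them toward the zero initialization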
@@ -264,6 +287,7 @@ class DoReFaQuantizer(Quantizer):
     Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
     (https://arxiv.org/abs/1606.06160)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
@@ -287,17 +311,20 @@ class DoReFaQuantizer(Quantizer):
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out -1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight = weight
+        # wrapper.module.weight.data = weight
+        return weight
 
     def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits)-1
-        output = torch.round(input_ri*scale)/scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
         return output
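The DoReFa-Net weight transform: squash with tanh, normalize into [0, 1], quantize on a uniform k-bit grid, then map back to [-1, 1]. A standalone trace of the pipeline above:

    import torch

    def dorefa_weight(weight, k):
        w = weight.tanh()
        w = w / (2 * w.abs().max()) + 0.5     # normalize into [0, 1]
        scale = (1 << k) - 1
        w = torch.round(w * scale) / scale    # k-bit uniform grid
        return 2 * w - 1                      # back into [-1, 1]

    w = torch.tensor([-2.0, -0.1, 0.3, 1.5])
    print(dorefa_weight(w, k=2))  # tensor([-1.0000, -0.3333,  0.3333,  1.0000])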
@@ -314,6 +341,7 @@ class BNNQuantizer(Quantizer):
     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = ClipGrad
@@ -339,11 +367,13 @@ class BNNQuantizer(Quantizer):
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = torch.sign(weight)
         # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight = weight
+        return weight
 
     def quantize_output(self, output, wrapper, **kwargs):
         out = torch.sign(output)
...
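BNN binarization is a bare torch.sign with the zero case pushed to +1 so no weight collapses to 0; the matching gradient behavior comes from the ClipGrad function registered in __init__ above. Standalone:

    import torch

    weight = torch.tensor([-0.7, 0.0, 0.3])
    binary = torch.sign(weight)     # tensor([-1., 0., 1.])
    binary[binary == 0] = 1         # zeros -> +1: tensor([-1., 1., 1.])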
@@ -234,20 +234,34 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
         quantizer.compress()
 
         # test quantize
         # range not including 0
         eps = 1e-7
         weight = torch.tensor([[1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
 
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
 
+        # test value of weight and bias after quantization
+        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
+        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
+        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        model.conv2.module.old_weight.data = weight
+        model.conv2.module.bias.data = bias
+        quantizer.quantize_weight(model.conv2)
+        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
+        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))
+
         # test ema
+        eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
         assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
...
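The weight_valid values in the new test can be reproduced by hand: rmin is clamped to 0 (the weights are all positive), so scale = 5.9723 / 255 and zero_point = 0, and each entry becomes round(w / scale) * scale. A quick check, assuming that affine scheme:

    scale = 5.9723 / 255                  # zero_point = 0
    for w in [1.1287, 2.3456, 3.7814, 5.9723]:
        print(round(round(w / scale) * scale, 4))
    # -> 1.1242, 2.3421, 3.7707, 5.9723, matching weight_valid;
    # the 32-bit bias round trip is lossless at rtol=1e-7, so bias_valid == bias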