"src/vscode:/vscode.git/clone" did not exist on "924225ed5f393fc620344e1e907769209ed11f06"
Unverified Commit 0a6c234a authored by lin bin's avatar lin bin Committed by GitHub
Browse files

Add bias quantization in QAT and refactor the code of weight quantization (#2914)

parent 6126960c
......@@ -481,11 +481,10 @@ class QuantizerModuleWrapper(torch.nn.Module):
            self)
        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            new_weight = self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad.apply(
                self.module.old_weight,
                QuantType.QUANT_WEIGHT,
                self)
-            self.module.weight = new_weight
            result = self.module(*inputs)
        else:
            result = self.module(*inputs)
......@@ -617,7 +616,7 @@ class QuantGrad(torch.autograd.Function):
        if quant_type == QuantType.QUANT_INPUT:
            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_OUTPUT:
            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
        else:
......
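Note on the refactor shown above: the wrapper no longer assigns the value returned by quant_grad.apply; each quantizer's quantize_weight now receives only the wrapper, reads the full-precision copy from wrapper.module.old_weight, and writes the simulated-quantized result back to wrapper.module.weight itself. A minimal sketch of the new contract, using a toy quantizer and a stand-in wrapper (names other than old_weight/weight are illustrative, not NNI API):

import copy
import types
import torch

class ToyQuantizer:
    """Hypothetical quantizer following the new contract: quantize_weight
    takes only the wrapper, reads module.old_weight and assigns the
    simulated-quantized tensor to module.weight as a side effect."""
    def quantize_weight(self, wrapper, **kwargs):
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        scale = weight.abs().max() / 127                     # symmetric int8 scale
        weight = weight.div(scale).round().clamp(-128, 127).mul(scale)
        wrapper.module.weight = weight                       # assignment now lives here
        return weight

# stand-in for the wrapped module: old_weight keeps full precision,
# weight receives the quantized value used in the forward pass
module = types.SimpleNamespace(old_weight=torch.nn.Parameter(torch.randn(4, 4)),
                               weight=None)
wrapper = types.SimpleNamespace(module=module)
ToyQuantizer().quantize_weight(wrapper)
print(module.weight)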
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import logging
+import copy
import torch
from schema import Schema, And, Or, Optional
from ..utils.config_validation import CompressorSchema
......@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
......@@ -29,13 +31,15 @@ class NaiveQuantizer(Quantizer):
schema.validate(config_list)
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
        new_scale = weight.abs().max() / 127
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
        self.layer_scale[wrapper.name] = scale
        orig_type = weight.type()  # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight = weight
+        return weight
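For reference, the round trip behind the NaiveQuantizer scheme above: the per-tensor scale is max|w| / 127 (kept as a running maximum per layer in self.layer_scale), the weight is cast to int8 and scaled back. A small self-contained check with arbitrary values:

import torch

w = torch.randn(3, 3)
scale = w.abs().max() / 127
w_q = w.div(scale).type(torch.int8)      # int8 codes, truncated toward zero
w_dq = w_q.type(w.type()).mul(scale)     # dequantized approximation of w
print((w - w_dq).abs().max())            # error is at most about one scale step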
def update_ema(biased_ema, value, decay, step):
"""
......@@ -60,6 +64,7 @@ def update_ema(biased_ema, value, decay, step):
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction
return biased_ema, unbiased_ema
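The step argument exists for the bias correction on the line above: a moving average started at zero underestimates the true statistic early in training, and dividing by 1 - decay ** step removes that startup bias. A quick numeric illustration of the same formula (a sketch, not the NNI function itself):

def ema_sketch(biased_ema, value, decay, step):
    # same update rule as update_ema above, written out for illustration
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)   # bias correction
    return biased_ema, unbiased_ema

ema = 0.0
for step in range(1, 4):
    ema, corrected = ema_sketch(ema, 10.0, decay=0.99, step=step)
    print(step, round(ema, 4), corrected)   # corrected stays at ~10.0 from step 1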
def update_quantization_param(bits, rmin, rmax):
"""
calculate the `zero_point` and `scale`.
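From the function name and the tests further below, update_quantization_param follows the usual asymmetric (affine) scheme from the QAT paper: extend [rmin, rmax] so it contains zero, spread it over the 2**bits integer levels, and round the zero point onto that grid. A minimal sketch under that assumption (the exact clamping and rounding in the real function may differ):

import torch

def quant_params_sketch(bits, rmin, rmax):
    # the representable range must contain 0 so that 0.0 maps to an exact code
    rmin = torch.clamp(rmin, max=0.0)
    rmax = torch.clamp(rmax, min=0.0)
    qmin, qmax = 0, (1 << bits) - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = torch.clamp(torch.round(qmin - rmin / scale), qmin, qmax)
    return scale, zero_point

# matches the test case below: a weight range of [-1, 5] at 8 bits gives
# scale = 6 / 255 and a zero point of about 42.5, hence the `in (42, 43)` check
print(quant_params_sketch(8, torch.tensor(-1.0), torch.tensor(5.0)))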
......@@ -116,6 +121,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
......@@ -215,20 +221,35 @@ class QAT_Quantizer(Quantizer):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
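The two context lines above are the dequantization half of the affine scheme; the matching _quantize step (not shown in this hunk) maps a real value onto the integer grid before it is mapped back. A self-contained sketch of the pair, with hypothetical helper names:

import torch

def quantize_sketch(bits, scale, zero_point, real_val):
    # real value -> clamped integer code
    qmin, qmax = 0, (1 << bits) - 1
    return torch.clamp(torch.round(real_val / scale + zero_point), qmin, qmax)

def dequantize_sketch(scale, zero_point, quantized_val):
    # integer code -> real value, mirroring `scale * (quantized_val - zero_point)`
    return scale * (quantized_val - zero_point)

x = torch.tensor([0.0, 1.0, 2.5, 5.0])
scale, zero_point = torch.tensor(5 / 255), 0
q = quantize_sketch(8, scale, zero_point, x)
print(dequantize_sketch(scale, zero_point, q))   # ~[0.0, 1.0, 2.5, 5.0]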
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
        config = wrapper.config
        module = wrapper.module
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight_bits = get_bits_length(config, 'weight')
        quant_start_step = config.get('quant_start_step', 0)
        assert weight_bits >= 1, "quant bits length should be at least 1"
        if quant_start_step > self.steps:
            return weight
+        # if bias exists, quantize bias to uint32
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+            bias_bits = 32
+            rmin, rmax = torch.min(bias), torch.max(bias)
+            module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+            bias = self._quantize(bias_bits, module, bias)
+            bias = self._dequantize(module, bias)
+            wrapper.module.bias.data = bias
+        # quantize weight
        rmin, rmax = torch.min(weight), torch.max(weight)
        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight = weight
+        return weight
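Both the new bias branch and the existing weight path are simulated quantization: the tensor is pushed onto the integer grid and immediately dequantized, so the module keeps computing in floating point but sees the quantization error. A sketch of that round trip using the values from the test case further below (fake_quantize is an illustrative helper, not the NNI API; zero points are 0 here because both tensors are non-negative, so the clamped rmin is 0):

import torch

def fake_quantize(x, bits, scale, zero_point):
    # quantize-then-dequantize round trip
    qmin, qmax = 0, (1 << bits) - 1
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

# weight path: 8-bit fake quantization
w = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
print(fake_quantize(w, 8, w.max() / 255, 0))    # ~[[1.1242, 2.3421], [3.7707, 5.9723]]

# bias path added in this commit: the same round trip at 32 bits, where the
# grid is so fine that the dequantized bias is numerically unchanged
b = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
print(fake_quantize(b, 32, b.max() / (2 ** 32 - 1), 0))   # ~[2.3432, 3.4342, 1.3414, 5.2341]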
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
......@@ -241,8 +262,10 @@ class QAT_Quantizer(Quantizer):
            return output
        current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
        out = self._quantize(output_bits, module, output)
        out = self._dequantize(module, out)
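For activations, the range is not taken from a single batch: quantize_output feeds the per-batch min and max through the bias-corrected EMA above, and the smoothed tracked_min/tracked_max are what update_quantization_param sees. A short sketch of that tracking step (attribute handling simplified; the quantization round trip afterwards is the same as for weights):

import torch

decay, step = 0.99, 1
tracked_min_biased, tracked_max_biased = 0.0, 0.0

output = torch.tensor([[-0.2, 0.0], [0.1, 0.2]])          # one batch of activations
current_min, current_max = torch.min(output), torch.max(output)

tracked_min_biased = tracked_min_biased * decay + (1 - decay) * current_min
tracked_min = tracked_min_biased / (1 - decay ** step)    # bias-corrected running min
tracked_max_biased = tracked_max_biased * decay + (1 - decay) * current_max
tracked_max = tracked_max_biased / (1 - decay ** step)    # bias-corrected running max

print(tracked_min, tracked_max)    # ~(-0.2, 0.2) after the first step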
......@@ -264,6 +287,7 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
......@@ -287,17 +311,20 @@ class DoReFaQuantizer(Quantizer):
schema.validate(config_list)
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out -1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight = weight
+        # wrapper.module.weight.data = weight
+        return weight
    def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits)-1
-        output = torch.round(input_ri*scale)/scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
        return output
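Putting the DoReFa pieces together: tanh squashes the weight, dividing by twice the max absolute value and adding 0.5 moves it into [0, 1], quantize snaps it onto a uniform k-bit grid, and 2 * w - 1 maps it back to [-1, 1]. A compact sketch of the whole transform (illustrative, not the NNI API):

import torch

def dorefa_weight_sketch(weight, k_bits):
    w = weight.tanh()
    w = w / (2 * w.abs().max()) + 0.5             # now in [0, 1]
    levels = 2 ** k_bits - 1
    w = torch.round(w * levels) / levels          # uniform k-bit grid in [0, 1]
    return 2 * w - 1                              # back to [-1, 1]

w = torch.randn(4, 4)
print(torch.unique(dorefa_weight_sketch(w, 2)))   # at most 2**2 = 4 distinct values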
......@@ -314,6 +341,7 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
......@@ -339,11 +367,13 @@ class BNNQuantizer(Quantizer):
schema.validate(config_list)
-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = torch.sign(weight)
        # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight = weight
+        return weight
def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
......
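The BNN path reduces to a sign function, with the corner case that torch.sign maps exact zeros to 0, which would leave a third value; those entries are forced to +1 so every weight ends up in {-1, +1}. A tiny check:

import torch

w = torch.tensor([[0.7, -0.2], [0.0, 1.5]])
b = torch.sign(w)
b[b == 0] = 1          # zeros would otherwise stay 0; push them to +1
print(b)               # [[ 1., -1.], [ 1.,  1.]]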
......@@ -234,20 +234,34 @@ class CompressorTestCase(TestCase):
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
# test quantize
# range not including 0
eps = 1e-7
        weight = torch.tensor([[1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)
+        # test value of weight and bias after quantization
+        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
+        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
+        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        model.conv2.module.old_weight.data = weight
+        model.conv2.module.bias.data = bias
+        quantizer.quantize_weight(model.conv2)
+        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
+        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))
# test ema
eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
......