Unverified Commit a6069165 authored by chenbohua3's avatar chenbohua3 Committed by GitHub
Browse files

fix wrong quantization target in weight quantization (#4038)

parent e9c21fd3
......@@ -2,7 +2,6 @@
# Licensed under the MIT license.
import logging
import copy
from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
......@@ -36,7 +35,7 @@ class NaiveQuantizer(Quantizer):
schema.validate(config_list)
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight = wrapper.module.weight
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
self.layer_scale[wrapper.name] = scale
......@@ -218,10 +217,8 @@ class ObserverQuantizer(Quantizer):
# the Pseudo-quantized one. So there is no need to quantize it
if self.compressed:
return
module = wrapper.module
old_weight = module.weight
self.record(wrapper, 'weight', old_weight)
weight = wrapper.module.weight
self.record(wrapper, 'weight', weight)
def quantize_output(self, output, wrapper, **kwargs):
if self.compressed:
......@@ -474,8 +471,8 @@ class QAT_Quantizer(Quantizer):
def quantize_weight(self, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight = module.weight
input = kwargs['input_tensor'] # pylint: disable=redefined-builtin
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
......@@ -675,7 +672,7 @@ class DoReFaQuantizer(Quantizer):
schema.validate(config_list)
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight = wrapper.module.weight
weight_bits = get_bits_length(wrapper.config, 'weight')
weight = weight.tanh()
weight = weight / (2 * weight.abs().max()) + 0.5
......@@ -785,7 +782,7 @@ class BNNQuantizer(Quantizer):
schema.validate(config_list)
def quantize_weight(self, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight = wrapper.module.weight
weight = torch.sign(weight)
# remove zeros
weight[weight == 0] = 1
......@@ -944,11 +941,11 @@ class LsqQuantizer(Quantizer):
def quantize_weight(self, wrapper, **kwargs):
module = wrapper.module
weight = wrapper.module.weight
# todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
# bias
old_weight = module.old_weight
weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
weight = self.quantize(weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
module.weight = weight
return weight
......
......@@ -559,6 +559,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self.module.weight = new_weight
else:
new_weight = self.module.old_weight
self.module.weight = new_weight.data
self.quantizer.quant_grad(
new_weight,
......
......@@ -328,13 +328,13 @@ class CompressorTestCase(TestCase):
eps = 1e-7
input = torch.tensor([[0, 4], [2, 1]]).float()
weight = torch.tensor([[1, 2], [3, 5]]).float()
model.conv2.module.old_weight.data = weight
model.conv2.module.weight.data = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.old_weight.data = weight
model.conv2.module.weight = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
......@@ -343,7 +343,7 @@ class CompressorTestCase(TestCase):
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
model.conv2.module.old_weight.data = weight
model.conv2.module.weight = weight
model.conv2.module.bias.data = bias
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment