Unverified Commit a6069165 authored by chenbohua3, committed by GitHub

fix wrong quantization target in weight quantization (#4038)

parent e9c21fd3
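In short: each quantizer's quantize_weight used to deep-copy the wrapper's saved full-precision parameter and quantize that detached copy, while the wrapper (see the QuantizerModuleWrapper hunk below) is now the single place that stages the tensor to be quantized on module.weight. The pattern change, condensed from the hunks that follow:

```python
# before: quantize a detached deep copy of the saved fp32 parameter
weight = copy.deepcopy(wrapper.module.old_weight.data)

# after: quantize the tensor the wrapper staged on module.weight
weight = wrapper.module.weight
```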
@@ -2,7 +2,6 @@
 # Licensed under the MIT license.
 import logging
-import copy
 from collections import defaultdict
 import torch
 from schema import Schema, And, Or, Optional
@@ -36,7 +35,7 @@ class NaiveQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
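NaiveQuantizer keeps a running per-layer scale: each call computes max|w| / 127 and keeps whichever is larger, the new value or the previously recorded one, so the scale never shrinks across calls. A self-contained sketch of that symmetric int8 scheme; the update rule is taken from the hunk above, while the rounding step is an assumption about what the omitted tail of the method does:

```python
import torch

layer_scale = {}  # per-layer running scale, as in NaiveQuantizer

def naive_quantize_weight(name, weight):
    new_scale = weight.abs().max() / 127
    scale = max(layer_scale.get(name, 0), new_scale)  # never let the scale shrink
    layer_scale[name] = scale
    # symmetric int8 fake quantization: snap to the grid, stay in fp32
    return torch.round(weight / scale).clamp(-127, 127) * scale

print(naive_quantize_weight('conv2', torch.tensor([[1., 2.], [3., 5.]])))
```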
@@ -218,10 +217,8 @@ class ObserverQuantizer(Quantizer):
         # the Pseudo-quantized one. So there is no need to quantize it
         if self.compressed:
             return
-        module = wrapper.module
-        old_weight = module.weight
-        self.record(wrapper, 'weight', old_weight)
+        weight = wrapper.module.weight
+        self.record(wrapper, 'weight', weight)

     def quantize_output(self, output, wrapper, **kwargs):
         if self.compressed:
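During calibration, ObserverQuantizer does not transform anything: record hands the weight to an observer that tracks its range, and quantization parameters are derived later. A toy min/max observer showing the record-then-compute pattern (NNI delegates this to observer classes; everything below is illustrative):

```python
import torch

class MinMaxObserver:
    """Toy min/max observer: tracks the running range of recorded tensors."""
    def __init__(self):
        self.min_val = None
        self.max_val = None

    def record(self, tensor):
        t_min, t_max = tensor.min(), tensor.max()
        self.min_val = t_min if self.min_val is None else torch.minimum(self.min_val, t_min)
        self.max_val = t_max if self.max_val is None else torch.maximum(self.max_val, t_max)

    def calculate_qparams(self, qmin=0, qmax=255):
        # affine uint8 parameters over the observed range, widened to include 0
        min_val = torch.minimum(self.min_val, torch.zeros(()))
        max_val = torch.maximum(self.max_val, torch.zeros(()))
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = (qmin - torch.round(min_val / scale)).clamp(qmin, qmax)
        return scale, zero_point

obs = MinMaxObserver()
obs.record(torch.tensor([[-1., 2.], [3., 5.]]))
print(obs.calculate_qparams())  # scale = 6/255, zero_point = 42
```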
@@ -474,8 +471,8 @@ class QAT_Quantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = module.weight
         input = kwargs['input_tensor']  # pylint: disable=redefined-builtin
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
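Once past quant_start_step, QAT fake-quantizes weights on the forward pass: compute affine scale and zero_point from the tensor's range (widened to include zero), round onto the integer grid, then dequantize so training keeps running in fp32. A minimal sketch of that affine fake-quantization step; NNI's actual helper names differ:

```python
import torch

def qat_fake_quantize(weight, weight_bits=8):
    # affine fake quantization: quantize to the integer grid, then dequantize,
    # so the forward pass sees quantization error while staying in fp32
    qmin, qmax = 0, (1 << weight_bits) - 1
    w_min = torch.minimum(weight.min(), torch.zeros(()))
    w_max = torch.maximum(weight.max(), torch.zeros(()))
    scale = (w_max - w_min) / (qmax - qmin)
    zero_point = (qmin - torch.round(w_min / scale)).clamp(qmin, qmax)
    q = torch.clamp(torch.round(weight / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

w = torch.tensor([[1.1243, 2.3432], [3.7707, 5.9723]])
print(qat_fake_quantize(w))  # each entry snaps to the nearest grid point
```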
@@ -675,7 +672,7 @@ class DoReFaQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         weight_bits = get_bits_length(wrapper.config, 'weight')
         weight = weight.tanh()
         weight = weight / (2 * weight.abs().max()) + 0.5
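The DoReFa-Net recipe visible in this hunk squashes weights through tanh and rescales them into [0, 1]; the lines cut off below then quantize uniformly to weight_bits and map back to [-1, 1]. A compact sketch of the full transform, where the uniform rounding step is the paper's standard quantize_k, assumed here rather than copied from NNI:

```python
import torch

def dorefa_quantize_weight(weight, weight_bits=8):
    # DoReFa-Net: squash with tanh, rescale into [0, 1], quantize uniformly
    # to weight_bits, then map the result back to [-1, 1]
    weight = weight.tanh()
    weight = weight / (2 * weight.abs().max()) + 0.5
    scale = (1 << weight_bits) - 1
    weight = torch.round(weight * scale) / scale  # quantize_k: uniform grid in [0, 1]
    return 2 * weight - 1

print(dorefa_quantize_weight(torch.randn(3, 3), weight_bits=2))
```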
@@ -785,7 +782,7 @@ class BNNQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         weight = torch.sign(weight)
         # remove zeros
         weight[weight == 0] = 1
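BNNQuantizer binarizes weights with sign and maps exact zeros to +1, since sign(0) is 0 and a zero weight would otherwise escape the {-1, +1} codebook. The same two lines as a runnable snippet:

```python
import torch

def bnn_binarize(weight):
    weight = torch.sign(weight)
    weight[weight == 0] = 1  # sign(0) == 0; keep the codebook at {-1, +1}
    return weight

print(bnn_binarize(torch.tensor([[-0.3, 0.0], [1.7, -2.0]])))
# tensor([[-1.,  1.],
#         [ 1., -1.]])
```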
@@ -944,11 +941,11 @@ class LsqQuantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         module = wrapper.module
+        weight = wrapper.module.weight
         # todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
         # bias
-        old_weight = module.old_weight
-        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
+        weight = self.quantize(weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
         module.weight = weight
         return weight
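self.quantize here follows Learned Step Size Quantization (Esser et al.): divide by a learned step size, clamp to [qmin, qmax], round with a straight-through estimator, and multiply back. A sketch under those assumptions; the helper names are illustrative, not NNI's:

```python
import torch

def round_pass(x):
    # straight-through estimator: round on the forward pass, identity gradient
    return (x.round() - x).detach() + x

def lsq_quantize(weight, scale, qmin, qmax):
    w = torch.clamp(weight / scale, qmin, qmax)
    return round_pass(w) * scale

scale = torch.tensor(0.05, requires_grad=True)  # learned step size
q = lsq_quantize(torch.randn(2, 2), scale, qmin=-128, qmax=127)
q.sum().backward()  # gradients flow to the step size through the STE
print(scale.grad)
```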
...
@@ -559,6 +559,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             self.module.weight = new_weight
         else:
             new_weight = self.module.old_weight
+            self.module.weight = new_weight.data
         self.quantizer.quant_grad(
             new_weight,
...
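This one-line wrapper change is the other half of the fix: in that branch, module.weight was previously left holding whatever an earlier forward pass wrote, so a quantizer reading module.weight could see a stale or already-quantized tensor. The added line stages the current fp32 value (old_weight) on module.weight before quant_grad dispatches to quantize_weight. A runnable toy version of the whole wrapper/quantizer contract, with simplified names that only approximate NNI's QuantizerModuleWrapper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyQuantWrapper(nn.Module):
    """Toy stand-in for NNI's QuantizerModuleWrapper; names simplified."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        # keep the trainable fp32 copy as 'old_weight'; 'weight' becomes a
        # plain attribute that the wrapper and quantizer overwrite freely
        self.module.old_weight = nn.Parameter(module.weight.data.clone())
        del self.module.weight

    def quantize_weight(self):
        # after the fix: read the tensor the wrapper staged on module.weight
        weight = self.module.weight
        scale = weight.abs().max() / 127
        self.module.weight = torch.round(weight / scale) * scale

    def forward(self, x):
        # stage the current fp32 value, then let the quantizer replace it
        self.module.weight = self.module.old_weight.data
        self.quantize_weight()
        return F.linear(x, self.module.weight, self.module.bias)

print(ToyQuantWrapper(nn.Linear(4, 4))(torch.randn(2, 4)).shape)
```

The point of the indirection: old_weight stays a trainable Parameter, while weight is a disposable attribute the quantizer may overwrite on every forward pass.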
@@ -328,13 +328,13 @@ class CompressorTestCase(TestCase):
         eps = 1e-7
         input = torch.tensor([[0, 4], [2, 1]]).float()
         weight = torch.tensor([[1, 2], [3, 5]]).float()
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight.data = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
@@ -343,7 +343,7 @@
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
         bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
         bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight = weight
         model.conv2.module.bias.data = bias
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
...
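The zero_point assertion in the test admits two values because the expected point lands exactly on a rounding boundary: with observed minimum -1 and scale 6/255, zero_point = 0 - round(-1 / (6/255)) = round(42.5), which is 42 under round-half-to-even but 43 under round-half-up. A quick illustration:

```python
scale = 6 / 255
zp = -(-1) / scale    # 42.5, exactly on the boundary
print(round(zp))      # 42: Python rounds half to even
print(int(zp + 0.5))  # 43: round-half-up convention
```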