Unverified commit f13a9cd4, authored by lin bin and committed via GitHub
Browse files

[Quantization] fix QAT export param (#4252)

parent cdb65dac
......@@ -26,7 +26,6 @@ __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer',
logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
......@@ -676,17 +675,20 @@ class QAT_Quantizer(Quantizer):
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if name not in calibration_config:
if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
logger.warning(f"Can not find module {name}'s parameter in input config.")
continue
if hasattr(module, 'weight_bits'):
assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
if hasattr(module, 'input_bits'):
assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
if module.layer_quant_setting.weight:
assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
f"weight bits of module {name} fail to match"
if module.layer_quant_setting.input:
assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
f"input bits of module {name} fail to match"
module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
if hasattr(module, 'output_bits'):
assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
if module.layer_quant_setting.output:
assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
f"output bits of module {name} fail to match"
module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
......@@ -716,11 +718,13 @@ class QAT_Quantizer(Quantizer):
self._unwrap_model()
calibration_config = {}
for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
modules_to_compress = self.get_modules_to_compress()
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if hasattr(module.layer_quant_setting, 'weight') or hasattr(module.layer_quant_setting, 'output'):
calibration_config[name] = {}
if hasattr(module, 'weight_bits'):
calibration_config[name]['weight_bits'] = int(module.weight_bits)
if module.layer_quant_setting.weight:
calibration_config[name]['weight_bits'] = int(module.layer_quant_setting.weight.bits)
calibration_config[name]['weight_scale'] = module.weight_scale
calibration_config[name]['weight_zero_point'] = module.weight_zero_point
......@@ -738,13 +742,14 @@ class QAT_Quantizer(Quantizer):
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if hasattr(module, 'input_bits'):
calibration_config[name]['input_bits'] = int(module.input_bits)
if module.layer_quant_setting.input:
calibration_config[name]['input_bits'] = int(module.layer_quant_setting.input.bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
if hasattr(module, 'output_bits'):
calibration_config[name]['output_bits'] = int(module.output_bits)
if module.layer_quant_setting.output:
calibration_config[name]['output_bits'] = int(module.layer_quant_setting.output.bits)
calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
self._del_simulated_attr(module)
......
......@@ -79,5 +79,5 @@ def get_quant_shape(shape, quant_type, quant_scheme):
if is_per_channel(quant_scheme):
quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
else:
quant_shape = []
quant_shape = [1]
return quant_shape
......@@ -9,7 +9,7 @@ import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape
import math
......@@ -398,11 +398,11 @@ class CompressorTestCase(TestCase):
target_zero_point = torch.ones([2, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(18. / 127)
target_zero_point = torch.zeros([])
target_scale = torch.tensor([18. / 127])
target_zero_point = torch.zeros([1])
else:
target_scale = torch.tensor(18. / 127.5)
target_zero_point = torch.ones([]) * 127
target_scale = torch.tensor([18. / 127.5])
target_zero_point = torch.ones([1]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
if dtype == 'int':
......@@ -413,10 +413,10 @@ class CompressorTestCase(TestCase):
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(18. / 254)
target_scale = torch.tensor([18. / 254])
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(18. / 255)
target_scale = torch.tensor([18. / 255])
target_zero_point = 0 - torch.round(0 / target_scale)
wrapper = getattr(model, name)
wrapper.module.weight = weight
......@@ -434,11 +434,11 @@ class CompressorTestCase(TestCase):
target_zero_point = torch.ones([1, 1, 1, 1]) * 127
elif qscheme == 'per_tensor_symmetric':
if dtype == 'int':
target_scale = torch.tensor(15. / 127)
target_zero_point = torch.zeros([])
target_scale = torch.tensor([15. / 127])
target_zero_point = torch.zeros([1])
else:
target_scale = torch.tensor(15. / 127.5)
target_zero_point = torch.ones([]) * 127
target_scale = torch.tensor([15. / 127.5])
target_zero_point = torch.ones([1]) * 127
elif qscheme == 'per_channel_affine':
min_val = torch.tensor([0.]).view([1, 1, 1, 1])
if dtype == 'int':
......@@ -449,10 +449,10 @@ class CompressorTestCase(TestCase):
target_zero_point = 0 - torch.round(min_val / target_scale)
else:
if dtype == 'int':
target_scale = torch.tensor(15. / 254)
target_scale = torch.tensor([15. / 254])
target_zero_point = -127 - torch.round(0 / target_scale)
else:
target_scale = torch.tensor(15. / 255)
target_scale = torch.tensor([15. / 255])
target_zero_point = 0 - torch.round(0 / target_scale)
quantizer.quantize_input(inp, wrapper)
self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
......@@ -488,7 +488,7 @@ class CompressorTestCase(TestCase):
assert model.conv2.module.weight_zero_point == 0
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.weight = weight
......@@ -497,7 +497,7 @@ class CompressorTestCase(TestCase):
assert model.conv2.module.weight_zero_point in (42, 43)
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
......@@ -513,14 +513,14 @@ class CompressorTestCase(TestCase):
eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.])))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2])))
quantizer.step_with_optimizer()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
model.relu(x)
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.002])))
self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2060])))
def test_torch_quantizer_export(self):
config_list_qat = [{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment