Unverified Commit 19914055 authored by lin bin's avatar lin bin Committed by GitHub
Browse files

[Quantization] fix parameter assign problem of weight (#4138)

parent 681ccc59
......@@ -253,7 +253,7 @@ class ObserverQuantizer(Quantizer):
module.weight_qmin,
module.weight_qmax)
delattr(module, 'weight')
module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
module.register_buffer('weight', quantized_weight)
if "input" in config.get("quant_types", []):
scale, zero_point = self.calculate_qparams(layer.name, 'input')
module.register_buffer('input_scale', scale.to(self.device))
......@@ -301,6 +301,14 @@ class ObserverQuantizer(Quantizer):
calibration_config[name]['tracked_min_weight'] = -val
calibration_config[name]['tracked_qmin_weight'] = -127
calibration_config[name]['tracked_qmax_weight'] = 127
weight = module.weight
quantized_weight = self._quantize(weight,
module.weight_scale,
module.weight_zero_point,
module.weight_qmin,
module.weight_qmax)
delattr(module, 'weight')
module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
# refactor these magic numbers when customizations of dtype and qscheme are ready.
if hasattr(module, 'input_scale'):
calibration_config[name]['input_bits'] = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment