Unverified Commit 9f6c6752 authored by lin bin's avatar lin bin Committed by GitHub
Browse files

fix lsq export (#4264)


Co-authored-by: default avatarlinbinskn <linbinskn@outlook.com>
parent 7fe57a0e
...@@ -1162,7 +1162,7 @@ class LsqQuantizer(Quantizer): ...@@ -1162,7 +1162,7 @@ class LsqQuantizer(Quantizer):
calibration_config = {} calibration_config = {}
for name, module in self.bound_model.named_modules(): for name, module in self.bound_model.named_modules():
if hasattr(module, 'input_bits') or hasattr(module, 'output_bits'): if hasattr(module, 'input_bits') or hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
calibration_config[name] = {} calibration_config[name] = {}
if hasattr(module, 'weight_bits'): if hasattr(module, 'weight_bits'):
calibration_config[name]['weight_bits'] = int(module.weight_bits) calibration_config[name]['weight_bits'] = int(module.weight_bits)
...@@ -1182,6 +1182,11 @@ class LsqQuantizer(Quantizer): ...@@ -1182,6 +1182,11 @@ class LsqQuantizer(Quantizer):
module.register_parameter('bias', actual_bias) module.register_parameter('bias', actual_bias)
else: else:
setattr(module, 'bias', None) setattr(module, 'bias', None)
if hasattr(module, 'input_bits'):
calibration_config[name]['input_bits'] = int(module.input_bits)
abs_max_input = float(module.input_scale * module.input_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_input
calibration_config[name]['tracked_max_input'] = abs_max_input
if hasattr(module, 'output_bits'): if hasattr(module, 'output_bits'):
calibration_config[name]['output_bits'] = int(module.output_bits) calibration_config[name]['output_bits'] = int(module.output_bits)
abs_max_output = float(module.output_scale * module.output_qmax) abs_max_output = float(module.output_scale * module.output_qmax)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment