Unverified Commit 39e3a990 authored by chenbohua3, committed by GitHub

change signature of quantize_input (#4039)

parent 86335921
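The diff below narrows `quantize_input` from a varargs tuple to a single tensor. As a minimal before/after sketch of the hook shape (illustrative stand-ins only: `SketchQuantizer` is hypothetical, `torch.fake_quantize_per_tensor_affine` substitutes for NNI's internal `_quantize`/`quantize` helpers, and the scale/zero-point values are placeholders):

import torch

class SketchQuantizer:
    # Old signature: the hook received the wrapped module's inputs as
    # varargs, quantized only inputs[0], and had to rebuild the tuple.
    def quantize_input_old(self, *inputs, wrapper=None, **kwargs):
        new_input = torch.fake_quantize_per_tensor_affine(
            inputs[0], scale=0.1, zero_point=0, quant_min=0, quant_max=255)
        list_inp = list(inputs)
        list_inp[0] = new_input
        return tuple(list_inp)

    # New signature: single tensor in, single tensor out; the wrapper
    # now asserts the op has exactly one input before calling the hook.
    def quantize_input(self, inputs, wrapper=None, **kwargs):
        return torch.fake_quantize_per_tensor_affine(
            inputs, scale=0.1, zero_point=0, quant_min=0, quant_max=255)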
@@ -187,12 +187,7 @@ class ObserverQuantizer(Quantizer):
     def record(self, wrapper, quant_type, tensor):
         name = wrapper.name
         observer = self.all_observers[name][quant_type]
-        if isinstance(tensor, tuple):
-            # NB: This only works for single tensor
-            tensor = (t.cpu() for t in tensor)
-            observer(*tensor)
-        else:
-            observer(tensor.cpu())
+        observer(tensor.cpu())
 
     def calculate_qparams(self, name, quant_type):
         observer = self.all_observers[name][quant_type]
@@ -206,17 +201,14 @@ class ObserverQuantizer(Quantizer):
         x = (x - zero_point) * scale
         return x
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         if self.compressed:
             module = wrapper.module
-            new_input = self._quantize(inputs[0],
-                                       module.input_scale,
-                                       module.input_zero_point,
-                                       module.input_qmin,
-                                       module.input_qmax)
-            list_inp = list(inputs)
-            list_inp[0] = new_input
-            inputs = tuple(list_inp)
+            inputs = self._quantize(inputs,
+                                    module.input_scale,
+                                    module.input_zero_point,
+                                    module.input_qmin,
+                                    module.input_qmax)
         else:
             self.record(wrapper, 'input', inputs)
         return inputs
@@ -973,20 +965,16 @@ class LsqQuantizer(Quantizer):
         output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
         return output
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
-        # This is hacky since it is not recommended to modify a tuple
-        # NB: support layers with multi inputs
+    def quantize_input(self, inputs, wrapper, **kwargs):
         module = wrapper.module
         # initialize the scale
         if self.bound_model.steps == 1:
             qmax = module.input_qmax
-            init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            init_oup_scale = inputs.data.detach().abs().mean() * 2 / (qmax ** 0.5)
             module.input_scale.data = init_oup_scale
 
-        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
-        list_inp = list(inputs)
-        list_inp[0] = new_input
-        return tuple(list_inp)
+        inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
+        return inputs
 
     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
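For reference, the scale initialization kept in the LsqQuantizer hunk above follows the LSQ rule s = 2 * E[|x|] / sqrt(qmax) from Esser et al.'s LSQ paper. A tiny numeric sketch (placeholder tensor; qmax = 255 is an assumed 8-bit bound, not taken from the diff):

import torch

x = torch.randn(1, 3, 8, 8)   # placeholder activation batch
qmax = 255                    # assumed 8-bit upper bound for illustration
init_scale = x.detach().abs().mean() * 2 / (qmax ** 0.5)
print(float(init_scale))      # LSQ init: s = 2 * E[|x|] / sqrt(qmax)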
@@ -544,10 +544,12 @@ class QuantizerModuleWrapper(torch.nn.Module):
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
-            inputs = self.quantizer.quant_grad(
-                inputs,
+            assert len(inputs) == 1, "Quantization of input only supports ops with single input."
+            new_inp = self.quantizer.quant_grad(
+                inputs[0],
                 QuantType.QUANT_INPUT,
                 self)
+            inputs = (new_inp,)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             if self.bn_module is not None:
@@ -640,7 +642,7 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_output()')
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -912,7 +914,7 @@ def _check_bias(module):
 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
-        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
+        output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
     elif quant_type == QuantType.QUANT_WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
     elif quant_type == QuantType.QUANT_OUTPUT:
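Taken together, the wrapper and helper changes mean the QUANT_INPUT path now hands the hook a bare tensor rather than unpacking a tuple. A simplified end-to-end sketch (DummyQuantizer and this stripped-down quantize_helper are hypothetical stand-ins, not NNI's real classes):

import torch

class DummyQuantizer:
    def quantize_input(self, inputs, wrapper, **kwargs):
        # 'inputs' is a single tensor under the new signature.
        return torch.clamp(inputs, -1.0, 1.0)

def quantize_helper(tensor, wrapper, quantizer):
    # Mirrors the QUANT_INPUT branch above: no more *tensor unpacking.
    return quantizer.quantize_input(tensor, wrapper=wrapper)

out = quantize_helper(torch.randn(4) * 3, wrapper=None,
                      quantizer=DummyQuantizer())
print(out)  # values clipped to [-1, 1]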