Unverified commit 39e3a990, authored by chenbohua3, committed by GitHub

change signature of quantize_input (#4039)

parent 86335921
@@ -187,12 +187,7 @@ class ObserverQuantizer(Quantizer):
     def record(self, wrapper, quant_type, tensor):
         name = wrapper.name
         observer = self.all_observers[name][quant_type]
-        if isinstance(tensor, tuple):
-            # NB: This only works for single tensor
-            tensor = (t.cpu() for t in tensor)
-            observer(*tensor)
-        else:
-            observer(tensor.cpu())
+        observer(tensor.cpu())
 
     def calculate_qparams(self, name, quant_type):
         observer = self.all_observers[name][quant_type]
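For context: since `quantize_input` now receives a single tensor, `record` can always forward a CPU copy straight to the observer. A minimal sketch of that call pattern, using `torch.quantization.MinMaxObserver` as a stand-in for whichever observer class `all_observers` actually holds (the concrete observer is an assumption here, not taken from this diff):

import torch
from torch.quantization import MinMaxObserver

# Stand-in observer; NNI's ObserverQuantizer may be configured differently.
observer = MinMaxObserver(dtype=torch.quint8)

x = torch.randn(4, 8)
observer(x.cpu())                      # accumulate min/max stats, as record() now does
scale, zero_point = observer.calculate_qparams()
print(scale, zero_point)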
@@ -206,17 +201,14 @@ class ObserverQuantizer(Quantizer):
         x = (x - zero_point) * scale
         return x
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         if self.compressed:
             module = wrapper.module
-            new_input = self._quantize(inputs[0],
+            inputs = self._quantize(inputs,
                                        module.input_scale,
                                        module.input_zero_point,
                                        module.input_qmin,
                                        module.input_qmax)
-            list_inp = list(inputs)
-            list_inp[0] = new_input
-            inputs = tuple(list_inp)
         else:
             self.record(wrapper, 'input', inputs)
         return inputs
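The `_quantize` call above pairs with the dequantization visible in this hunk's context line, `x = (x - zero_point) * scale`. A minimal sketch of the affine scheme that context implies; the quantize body below is an assumption reconstructed from that line, not a copy of NNI's `_quantize`:

import torch

def affine_quantize(x, scale, zero_point, qmin, qmax):
    # map to the integer grid, then clamp to the representable range
    q = torch.round(x / scale + zero_point)
    return torch.clamp(q, qmin, qmax)

def affine_dequantize(q, scale, zero_point):
    # matches the context line: x = (x - zero_point) * scale
    return (q - zero_point) * scale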
@@ -973,20 +965,16 @@ class LsqQuantizer(Quantizer):
         output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
         return output
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
-        # This is hacky since it is not recommended to modify a tuple
-        # NB: support layers with multi inputs
+    def quantize_input(self, inputs, wrapper, **kwargs):
         module = wrapper.module
         # initialize the scale
         if self.bound_model.steps == 1:
             qmax = module.input_qmax
-            init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            init_oup_scale = inputs.data.detach().abs().mean() * 2 / (qmax ** 0.5)
             module.input_scale.data = init_oup_scale
 
-        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
-        list_inp = list(inputs)
-        list_inp[0] = new_input
-        return tuple(list_inp)
+        inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
+        return inputs
 
     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
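The hunk above keeps LSQ's step-size initialization, `2 * mean(|x|) / sqrt(qmax)`, from the LSQ paper (Esser et al., "Learned Step Size Quantization"), now applied to the single input tensor. A sketch of the usual LSQ quantize step with a straight-through estimator for the round; this is a generic reconstruction, not NNI's exact `self.quantize`:

import torch

def round_pass(x):
    # straight-through estimator: round in forward, identity gradient in backward
    return (x.round() - x).detach() + x

def lsq_quantize(x, scale, qmin, qmax):
    # clamp to the quantization grid, round with STE, then rescale
    q = torch.clamp(x / scale, qmin, qmax)
    return round_pass(q) * scale

# step-size init as in the hunk above: 2 * mean(|x|) / sqrt(qmax)
x = torch.randn(16, 32)
qmin, qmax = -128, 127
scale = torch.nn.Parameter(x.detach().abs().mean() * 2 / (qmax ** 0.5))
y = lsq_quantize(x, scale, qmin, qmax)   # differentiable w.r.t. scale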
@@ -544,10 +544,12 @@ class QuantizerModuleWrapper(torch.nn.Module):
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
-            inputs = self.quantizer.quant_grad(
-                inputs,
+            assert len(inputs) == 1, "Quantization of input only supports ops with single input."
+            new_inp = self.quantizer.quant_grad(
+                inputs[0],
                 QuantType.QUANT_INPUT,
                 self)
+            inputs = (new_inp,)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             if self.bn_module is not None:
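This hunk makes the single-input restriction explicit at the wrapper level: assert on arity, quantize `inputs[0]`, and repack a one-element tuple so the wrapped module's call site stays unchanged. A hypothetical standalone illustration of that repacking pattern (class and parameter names here are invented for the example):

import torch

class SingleInputQuantWrapper(torch.nn.Module):
    # Hypothetical stand-in for QuantizerModuleWrapper's input path.
    def __init__(self, module, quant_fn):
        super().__init__()
        self.module = module
        self.quant_fn = quant_fn

    def forward(self, *inputs):
        assert len(inputs) == 1, "input quantization supports single-input ops only"
        inputs = (self.quant_fn(inputs[0]),)   # quantize, then repack as a tuple
        return self.module(*inputs)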
@@ -640,7 +642,7 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_output()')
 
-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -912,7 +914,7 @@ def _check_bias(module):
 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
-        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
+        output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
     elif quant_type == QuantType.QUANT_WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
     elif quant_type == QuantType.QUANT_OUTPUT:
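Taken together, the commit settles `quantize_input` on a single-tensor signature across the base class, both quantizers, and the dispatch helper. A hedged sketch of what a custom override now looks like under that signature; the subclass, its quantization scheme, and the import path are illustrative assumptions, not part of this diff:

import torch
# NOTE: import path is an assumption and may differ across NNI versions.
from nni.compression.pytorch.compressor import Quantizer

class MyQuantizer(Quantizer):
    """Illustrative subclass; not part of this commit."""
    def quantize_input(self, inputs, wrapper, **kwargs):
        # Under the new signature, 'inputs' is a single tensor rather than
        # a *inputs varargs tuple, so no list/tuple repacking is needed.
        scale = wrapper.module.input_scale
        q = torch.clamp(torch.round(inputs / scale), -128, 127)
        return q * scale   # fake-quantized tensor flows back to forward()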