Unverified Commit 2f6a74f1 authored by Dalong, committed by GitHub

fix quant grad function calculation error (#3160)

parent 78e874f9
@@ -580,17 +580,55 @@ class QuantType:
     """
     Enum class for quantization type.
     """
-    QUANT_INPUT = 0
-    QUANT_WEIGHT = 1
-    QUANT_OUTPUT = 2
+    QUANT_INPUT = 'input'
+    QUANT_WEIGHT = 'weight'
+    QUANT_OUTPUT = 'output'


 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding backward function of quantization operation.
     """
+    @classmethod
+    def _quantize(cls, x, scale, zero_point):
+        """
+        Reference function for quantizing x -- non-clamped.
+        Parameters
+        ----------
+        x : Tensor
+            tensor to be quantized
+        scale : Tensor
+            scale for quantizing x
+        zero_point : Tensor
+            zero_point for quantizing x
+        Returns
+        -------
+        tensor
+            quantized x without clamped
+        """
+        return ((x / scale) + zero_point).round()
+
+    @classmethod
+    def get_bits_length(cls, config, quant_type):
+        """
+        Get bit for quantize config
+        Parameters
+        ----------
+        config : Dict
+            the configuration for quantization
+        quant_type : str
+            quant type
+        Returns
+        -------
+        int
+            n-bits for quantization configuration
+        """
+        if isinstance(config["quant_bits"], int):
+            return config["quant_bits"]
+        else:
+            return config["quant_bits"].get(quant_type)
+
     @staticmethod
-    def quant_backward(tensor, grad_output, quant_type):
+    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
         """
         This method should be overrided by subclass to provide customized backward function,
         default implementation is Straight-Through Estimator
@@ -600,32 +638,45 @@ class QuantGrad(torch.autograd.Function):
             input of quantization operation
         grad_output : Tensor
             gradient of the output of quantization operation
-        quant_type : QuantType
+        scale : Tensor
             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
             you can define different behavior for different types.
+        zero_point : Tensor
+            zero_point for quantizing tensor
+        qmin : Tensor
+            quant_min for quantizing tensor
+        qmax : Tensor
+            quant_max for quantizng tensor
         Returns
         -------
         tensor
             gradient of the input of quantization operation
         """
+        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
+        mask = (tensor_q < qmin) | (tensor_q > qmax)
+        grad_output[mask] = 0
         return grad_output

     @staticmethod
     def forward(ctx, tensor, quant_type, wrapper, **kwargs):
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
         if quant_type == QuantType.QUANT_INPUT:
-            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
-            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
             raise ValueError("unrecognized QuantType.")
+        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
+        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
+        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
+        return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type = ctx.saved_variables
-        output = cls.quant_backward(tensor, grad_output, quant_type)
+        tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
        return output, None, None, None


 def _check_weight(module):
...
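The corrected default backward is a clipped straight-through estimator: the input is re-quantized with the saved scale and zero_point, and the incoming gradient is zeroed wherever the quantized value falls outside [qmin, qmax]. Below is a minimal, self-contained sketch of that gradient rule, not the library code itself; the function name ste_clipped_grad and the example values are illustrative only, and the 8-bit range [0, 255] is an assumption for the demo.

import torch

# Sketch of the clipped straight-through estimator used as the default
# quant_backward after this fix. Names and values are illustrative.
def ste_clipped_grad(tensor, grad_output, scale, zero_point, qmin, qmax):
    tensor_q = (tensor / scale + zero_point).round()   # non-clamped quantization
    mask = (tensor_q < qmin) | (tensor_q > qmax)       # positions that would be clipped
    grad = grad_output.clone()                         # copy to avoid in-place mutation in the sketch
    grad[mask] = 0                                     # block gradient where clipping occurred
    return grad

# Hypothetical example with an assumed 8-bit unsigned range [0, 255]:
x = torch.tensor([-2.0, 0.5, 1.0, 300.0])
g = torch.ones_like(x)
print(ste_clipped_grad(x, g, scale=torch.tensor(1.0),
                       zero_point=torch.tensor(0.0), qmin=0, qmax=255))
# tensor([0., 1., 1., 0.])  -- gradients at -2.0 and 300.0 are zeroed

This is what the old implementation got wrong: it passed every gradient straight through regardless of the quantization range, whereas values that are clamped in the forward pass should not receive a gradient.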