Unverified commit 3b1d5cd4, authored by lin bin, committed by GitHub

Fix dorefa bnn (#3247)

parent 0a20c3fc
@@ -111,6 +111,15 @@ def get_bits_length(config, quant_type):
     return config["quant_bits"].get(quant_type)
 
+class QATGrad(QuantGrad):
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
+        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
+        mask = (tensor_q < qmin) | (tensor_q > qmax)
+        grad_output[mask] = 0
+        return grad_output
+
+
 class QAT_Quantizer(Quantizer):
     """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
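For readers skimming the diff: the new QATGrad implements a clipped straight-through estimator. It quantizes the saved input and zeroes the gradient wherever the quantized value falls outside the representable range [qmin, qmax]; in-range gradients pass through unchanged. A minimal standalone sketch of that masking logic, assuming an affine quantizer of the form round(x / scale + zero_point) (the exact formula of QuantGrad._quantize may differ):

import torch

def affine_quantize(x, scale, zero_point):
    # Hypothetical stand-in for QuantGrad._quantize: affine rounding.
    return torch.round(x / scale + zero_point)

def qat_backward_sketch(x, grad_output, scale, zero_point, qmin, qmax):
    # Gradients survive only where the quantized value stays in range.
    x_q = affine_quantize(x, scale, zero_point)
    mask = (x_q < qmin) | (x_q > qmax)
    grad = grad_output.clone()  # avoid mutating the caller's gradient
    grad[mask] = 0
    return grad

x = torch.linspace(-2, 2, 9)
grads = qat_backward_sketch(x, torch.ones_like(x),
                            scale=torch.tensor(1 / 128), zero_point=torch.tensor(128.0),
                            qmin=0, qmax=255)
print(grads)  # zeros where round(x / scale + zero_point) leaves [0, 255]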
@@ -138,6 +147,7 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
+        self.quant_grad = QATGrad
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
@@ -331,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
 class ClipGrad(QuantGrad):
     @staticmethod
-    def quant_backward(tensor, grad_output, quant_type):
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
         if quant_type == QuantType.QUANT_OUTPUT:
             grad_output[torch.abs(tensor) > 1] = 0
         return grad_output
...
@@ -580,10 +580,15 @@ class QuantType:
     """
     Enum class for quantization type.
     """
-    QUANT_INPUT = 'input'
-    QUANT_WEIGHT = 'weight'
-    QUANT_OUTPUT = 'output'
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
+
+QType_Dict = {
+    0: "input",
+    1: "weight",
+    2: "output"
+}
 
 class QuantGrad(torch.autograd.Function):
     """
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
         return config["quant_bits"].get(quant_type)
 
     @staticmethod
-    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
         """
         This method should be overridden by subclass to provide a customized backward function;
         the default implementation is Straight-Through Estimator
@@ -652,9 +657,6 @@ class QuantGrad(torch.autograd.Function):
         tensor
             gradient of the input of quantization operation
         """
-        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
-        mask = (tensor_q < qmin) | (tensor_q > qmax)
-        grad_output[mask] = 0
         return grad_output
 
     @staticmethod
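With the range masking moved into QATGrad, the base-class default degenerates to a pure straight-through estimator: for y = quantize(x), the backward pass pretends dy/dx = 1 and returns grad_output untouched. A minimal illustration of the idea:

import torch

class RoundSTE(torch.autograd.Function):
    # Pure straight-through estimator: forward rounds, backward is identity.
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # treat round() as if its derivative were 1

x = torch.tensor([0.3, 1.7], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.])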
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
         else:
             raise ValueError("unrecognized QuantType.")
 
-        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
-        qmin, qmax = torch.Tensor([0]).to(device=tensor.device), torch.Tensor([(1 << bits)-1]).to(device=tensor.device)
-        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
+        bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
+        qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
+        if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
+            scale = wrapper.module.scale
+            zero_point = wrapper.module.zero_point
+        else:
+            scale, zero_point = None, None
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
         return output
 
     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
-        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
+        tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
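This hunk is the heart of the fix: quant_type is carried through autograd as a one-element tensor so that backward can hand it to quant_backward, and scale/zero_point fall back to None for quantizers (such as DoReFa and BNN) whose modules never define them. The trailing Nones in backward are the gradients for the non-tensor forward arguments. A reduced sketch of the same pattern, with hypothetical names:

import torch

class FlaggedQuant(torch.autograd.Function):
    # Hypothetical reduction of the pattern above: stash a non-tensor
    # flag (quant_type) as a one-element tensor for use in backward.
    @staticmethod
    def forward(ctx, tensor, quant_type, scale, zero_point):
        output = torch.round(tensor / scale + zero_point)
        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor, quant_type, scale, zero_point = ctx.saved_tensors
        # A subclass could dispatch on int(quant_type) here; return one
        # value per forward input, None for the non-differentiable args.
        return grad_output, None, None, None

x = torch.randn(4, requires_grad=True)
y = FlaggedQuant.apply(x, 1, torch.tensor(0.1), torch.tensor(0.0))
y.sum().backward()
print(x.grad)  # all ones: the straight-through default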
 def _check_weight(module):
...