Commit 4f3ee9cb authored by Cjkkkk, committed by chicm-ms

add quantization backward support (#1854)

parent 7a558113
@@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer):
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
-        quantizer.step()
         if batch_idx % 100 == 0:
             print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
...
@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):
 class QAT_Quantizer(Quantizer):
-    """Quantizer using the DoReFa scheme, as defined in:
+    """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
@@ -227,20 +227,17 @@ class DoReFaQuantizer(Quantizer):
     (https://arxiv.org/abs/1606.06160)
     """
     def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - q_bits
-        """
         super().__init__(model, config_list)
 
     def quantize_weight(self, weight, config, **kwargs):
+        weight_bits = get_bits_length(config, 'weight')
         out = weight.tanh()
         out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, config['q_bits'])
+        out = self.quantize(out, weight_bits)
         out = 2 * out -1
         return out
 
     def quantize(self, input_ri, q_bits):
         scale = pow(2, q_bits)-1
         output = torch.round(input_ri*scale)/scale
         return output
\ No newline at end of file
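Note that DoReFaQuantizer now reads the bit width through get_bits_length(config, 'weight') instead of the removed q_bits key. That helper is referenced by the hunk header above but its body is not part of this diff; a minimal sketch consistent with the quant_bits assertions in _instrument_layer below might look like:

```python
def get_bits_length(config, quant_type):
    # quant_bits may be a single int applied to every quant type,
    # or a dict such as {'weight': 8, 'output': 8} keyed by quant type.
    # (Assumed sketch; the real helper is defined above this hunk and is not shown here.)
    if isinstance(config['quant_bits'], int):
        return config['quant_bits']
    return config['quant_bits'][quant_type]
```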
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list)
+        self.quant_grad = QuantGrad
+
     def quantize_weight(self, weight, config, op, op_type, op_name):
         """
         quantize should overload this method to quantize weight.
@@ -262,7 +266,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for weight quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_weight()")
+        raise NotImplementedError('Quantizer must overload quantize_weight()')
 
     def quantize_output(self, output, config, op, op_type, op_name):
         """
@@ -276,7 +280,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for output quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_output()")
+        raise NotImplementedError('Quantizer must overload quantize_output()')
 
     def quantize_input(self, *inputs, config, op, op_type, op_name):
         """
@@ -290,7 +294,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for inputs quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_input()")
+        raise NotImplementedError('Quantizer must overload quantize_input()')
 
     def _instrument_layer(self, layer, config):
@@ -305,62 +309,93 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
         assert layer._forward is None, 'Each model can only be compressed once'
-        assert "quant_types" in config, 'must provide quant_types in config'
-        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
-        assert "quant_bits" in config, 'must provide quant_bits in config'
-        assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
-        if isinstance(config["quant_bits"], dict):
-            for quant_type in config["quant_types"]:
-                assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
-        if 'weight' in config["quant_types"]:
+        assert 'quant_types' in config, 'must provide quant_types in config'
+        assert isinstance(config['quant_types'], list), 'quant_types must be list type'
+        assert 'quant_bits' in config, 'must provide quant_bits in config'
+        assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'
+        if isinstance(config['quant_bits'], dict):
+            for quant_type in config['quant_types']:
+                assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
+        if 'weight' in config['quant_types']:
             if not _check_weight(layer.module):
                 _logger.warning('Module %s does not have parameter "weight"', layer.name)
+            else:
+                # old_weight stores the original weight and weight stores the quantized weight.
+                # weight is registered as a buffer rather than a parameter because parameters
+                # are autograd leaves in pytorch; if weight were a leaf, old_weight could not be updated.
+                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
+                delattr(layer.module, 'weight')
+                layer.module.register_buffer('weight', layer.module.old_weight)
 
         layer._forward = layer.module.forward
 
         def new_forward(*inputs):
-            if 'input' in config["quant_types"]:
-                inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
+            if 'input' in config['quant_types']:
+                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
 
-            if 'weight' in config["quant_types"] and _check_weight(layer.module):
-                weight = layer.module.weight.data
-                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-                layer.module.weight.data = new_weight
+            if 'weight' in config['quant_types'] and _check_weight(layer.module):
+                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
+                layer.module.weight = new_weight
                 result = layer._forward(*inputs)
-                layer.module.weight.data = weight
             else:
                 result = layer._forward(*inputs)
 
-            if 'output' in config["quant_types"]:
-                result = straight_through_quantize_output.apply(result, self, config, layer)
+            if 'output' in config['quant_types']:
+                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
             return result
 
         layer.module.forward = new_forward
 
-class straight_through_quantize_output(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, output, quantizer, config, layer):
-        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
-
-class straight_through_quantize_input(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inputs, quantizer, config, layer):
-        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
+class QuantType:
+    """
+    Enum class for quantization type.
+    """
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
+
+class QuantGrad(torch.autograd.Function):
+    """
+    Base class for overriding the backward function of a quantization operation.
+    """
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type):
+        """
+        This method should be overridden by subclasses to provide a customized backward function;
+        the default implementation is the straight-through estimator.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            input of the quantization operation
+        grad_output : Tensor
+            gradient of the output of the quantization operation
+        quant_type : QuantType
+            the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`
+            or `QuantType.QUANT_OUTPUT`, so different behavior can be defined for different types.
+
+        Returns
+        -------
+        tensor
+            gradient of the input of the quantization operation
+        """
+        return grad_output
+
+    @staticmethod
+    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        tensor, quant_type = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, quant_type)
+        return output, None, None, None, None
 
 def _check_weight(module):
     try:
-        return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
+        return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
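The assertions in _instrument_layer define the per-layer config schema this commit expects: quant_types is a list drawn from 'input' / 'weight' / 'output', and quant_bits is either a single int or a dict keyed by quant type. Below is a minimal sketch of passing such a config to one of the built-in quantizers; the import path, the op_types selector key, and the compress() call are assumptions about the surrounding NNI API rather than part of this diff.

```python
import torch.nn as nn
from nni.compression.torch import QAT_Quantizer   # assumed import path, not shown in this diff

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))

config_list = [{
    'quant_types': ['weight', 'output'],        # which tensors of the matched layers to quantize
    'quant_bits': {'weight': 8, 'output': 8},   # either a single int or a dict keyed by quant type
    'op_types': ['Conv2d', 'Linear'],           # assumed selector key from the base Compressor config
}]

quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()   # assumed Compressor API that wraps the matched layers via _instrument_layer

# After this commit the example training loop no longer calls quantizer.step();
# the quantized weight is recomputed from module.old_weight on every forward pass.
```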
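The backward-support hook itself is QuantGrad: Quantizer.__init__ stores the class on self.quant_grad, new_forward routes every quantization through self.quant_grad.apply(...), and backward delegates to quant_backward, which defaults to the straight-through estimator. The sketch below shows how a custom quantizer might override that backward behavior; the import path is assumed, and ClippedQuantGrad / MyQuantizer are illustrative names, not part of this change.

```python
import torch
# assumed import path for the classes introduced or changed in this diff
from nni.compression.torch.compressor import Quantizer, QuantGrad, QuantType

class ClippedQuantGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        # Replace the default straight-through estimator with a clipped one:
        # zero the gradient wherever the pre-quantization value left [-1, 1].
        # quant_type (QUANT_INPUT / QUANT_WEIGHT / QUANT_OUTPUT) is available
        # if different behavior per tensor kind is wanted.
        return grad_output * (tensor.abs() <= 1).to(grad_output.dtype)

class MyQuantizer(Quantizer):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.quant_grad = ClippedQuantGrad   # picked up by new_forward via self.quant_grad.apply

    def quantize_weight(self, weight, config, **kwargs):
        # naive symmetric 8-bit rounding, for illustration only
        scale = weight.abs().max() / 127
        return torch.round(weight / scale) * scale
```

Because backward is a classmethod that dispatches through cls, the subclass's quant_backward is used automatically wherever the quantizer calls self.quant_grad.apply.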