Unverified Commit f51d985b authored by lin bin, committed by GitHub

Add model export for QAT (#3458)

parent 4635b559
@@ -110,7 +110,6 @@ def get_bits_length(config, quant_type):
else:
return config["quant_bits"].get(quant_type)
class QATGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
@@ -153,13 +152,26 @@ class QAT_Quantizer(Quantizer):
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('activation_bit', torch.zeros(1))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))
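The paired tracked_min_biased/tracked_min and tracked_max_biased/tracked_max buffers registered above hold a running estimate of the activation range, decayed by ema_decay. A minimal sketch of the bias-corrected EMA update that such buffer pairs typically back, assuming the standard formulation (the helper name update_ema is illustrative, not necessarily this module's API):

def update_ema(biased_ema, value, decay, step):
    # Accumulate the running statistic; dividing by (1 - decay ** step)
    # removes the bias introduced by the zero initialization in early steps.
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)
    return biased_ema, unbiased_ema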
def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
def validate_config(self, model, config_list):
"""
Parameters
@@ -256,6 +268,7 @@ class QAT_Quantizer(Quantizer):
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bit = torch.Tensor([weight_bits])
wrapper.module.weight = weight
return weight
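The scale/zero_point update and the quantize/dequantize pair used above appear to follow the usual asymmetric affine scheme. A self-contained sketch of that math, assuming standard uniform affine quantization (function names are illustrative, not the quantizer's internal API):

import torch

def affine_qparams(bits, rmin, rmax):
    # Derive scale and zero_point so that [rmin, rmax] maps onto [qmin, qmax].
    qmin, qmax = 0, (1 << bits) - 1
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)  # the real range must contain zero
    scale = max((rmax - rmin) / (qmax - qmin), 1e-8)
    zero_point = int(min(max(qmin - round(rmin / scale), qmin), qmax))
    return scale, zero_point

def fake_quantize(x, scale, zero_point, bits):
    # Quantize then immediately dequantize, so training sees the rounding error.
    qmin, qmax = 0, (1 << bits) - 1
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale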
@@ -263,6 +276,7 @@ class QAT_Quantizer(Quantizer):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
module.activation_bit = torch.Tensor([output_bits])
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
@@ -282,6 +296,47 @@ class QAT_Quantizer(Quantizer):
out = self._dequantize(module, out)
return out
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and, optionally, calibration parameters and an onnx model.
Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters collected during calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the dummy input used for onnx export
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file;
the tensor is placed on cpu if `device` is None
Returns
-------
Dict
calibration parameters of the quantized layers, keyed by layer name
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
self._del_simulated_attr(module)
self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
return calibration_config
def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
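A minimal usage sketch of the export flow added above; the model, config_list, optimizer and file paths are placeholders:

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... run quantization-aware training so scale/zero_point and the tracked ranges are updated ...
calibration_config = quantizer.export_model(
    model_path='qat_model.pth',              # quantized weights (state_dict)
    calibration_path='qat_calibration.pth',  # per-layer bit widths and tracked ranges
    onnx_path='qat_model.onnx',
    input_shape=(1, 1, 28, 28),
    device=torch.device('cpu'))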
@@ -301,6 +356,19 @@ class DoReFaQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'weight_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
def validate_config(self, model, config_list):
"""
@@ -330,6 +398,7 @@ class DoReFaQuantizer(Quantizer):
weight = self.quantize(weight, weight_bits)
weight = 2 * weight - 1
wrapper.module.weight = weight
wrapper.module.weight_bit = torch.Tensor([weight_bits])
# wrapper.module.weight.data = weight
return weight
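The lines above are the tail of the DoReFa weight transform; for context, a standalone sketch of DoReFa-style k-bit weight quantization as described in the DoReFa-Net paper (not the class's exact code):

import torch

def dorefa_quantize_weight(w, k):
    # Normalize weights into [0, 1] via tanh, quantize uniformly to k bits,
    # then map the result back into [-1, 1].
    scale = 2 ** k - 1
    w = torch.tanh(w)
    w = w / (2 * torch.max(torch.abs(w))) + 0.5
    w = torch.round(w * scale) / scale
    return 2 * w - 1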
@@ -338,6 +407,42 @@ class DoReFaQuantizer(Quantizer):
output = torch.round(input_ri * scale) / scale
return output
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and, optionally, calibration parameters and an onnx model.
Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters collected during calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the dummy input used for onnx export
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file;
the tensor is placed on cpu if `device` is None
Returns
-------
Dict
calibration parameters of the quantized layers, keyed by layer name
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit'):
calibration_config[name] = {}
calibration_config[name]['weight_bit'] = int(module.weight_bit)
self._del_simulated_attr(module)
self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
return calibration_config
class ClipGrad(QuantGrad):
@staticmethod
@@ -356,6 +461,19 @@ class BNNQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'weight_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
def validate_config(self, model, config_list):
"""
@@ -384,6 +502,7 @@ class BNNQuantizer(Quantizer):
# remove zeros
weight[weight == 0] = 1
wrapper.module.weight = weight
wrapper.module.weight_bit = torch.Tensor([1.0])
return weight
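The zero fix above is needed because sign(0) is 0, which is not a valid binary weight. A compact sketch of the whole binarization step, assuming the preceding (not shown) lines apply a sign function:

def binarize(w):
    # Map weights to {-1, +1}; values that are exactly zero are assigned +1.
    w = torch.sign(w)
    w[w == 0] = 1
    return w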
def quantize_output(self, output, wrapper, **kwargs):
@@ -391,3 +510,39 @@ class BNNQuantizer(Quantizer):
# remove zeros
out[out == 0] = 1
return out
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and, optionally, calibration parameters and an onnx model.
Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters collected during calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the dummy input used for onnx export
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file;
the tensor is placed on cpu if `device` is None
Returns
-------
Dict
calibration parameters of the quantized layers, keyed by layer name
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit'):
calibration_config[name] = {}
calibration_config[name]['weight_bit'] = int(module.weight_bit)
self._del_simulated_attr(module)
self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
return calibration_config
\ No newline at end of file
@@ -21,7 +21,6 @@ def _setattr(model, name, module):
model = getattr(model, name)
setattr(model, name_list[-1], module)
class Compressor:
"""
Abstract base PyTorch compressor
@@ -573,6 +572,67 @@ class Quantizer(Compressor):
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
input_shape=None, device=None):
"""
Helper that saves the pytorch model, the calibration config, and the onnx model for a quantizer.
Parameters
----------
model : pytorch model
pytorch model to be saved
model_path : str
path to save the pytorch model
calibration_config : dict
(optional) calibration parameters to save
calibration_path : str
(optional) path to save quantization parameters collected during calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the dummy input used for onnx export
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file;
the tensor is placed on cpu if `device` is None
"""
torch.save(model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
torch.save(calibration_config, calibration_path)
_logger.info('Calibration config saved to %s', calibration_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# a dummy input tensor is needed to trace the model for onnx export
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and, optionally, calibration parameters and an onnx model.
Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters collected during calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the dummy input used for onnx export
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file;
the tensor is placed on cpu if `device` is None
Returns
-------
Dict
calibration parameters collected by the concrete quantizer, keyed by layer name
"""
raise NotImplementedError('Quantizer must overload export_model()')
def step_with_optimizer(self):
pass
@@ -274,6 +274,55 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
def test_torch_quantizer_export(self):
config_list_qat = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types': ['ReLU']
}]
config_list_dorefa = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # an int could be used here directly, since all `quant_types` in this entry share the same bit length
'op_types':['Conv2d', 'Linear']
}]
config_list_bnn = [{
'quant_types': ['weight'],
'quant_bits': 1,
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 1,
'op_types': ['ReLU']
}]
config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
model = TorchModel()
model.relu = torch.nn.ReLU()
quantizer = quantize_algorithm(model, config)
quantizer.compress()
x = torch.rand((1, 1, 28, 28), requires_grad=True)
y = model(x)
y.backward(torch.ones_like(y))
model_path = "test_model.pth"
calibration_path = "test_calibration.pth"
onnx_path = "test_model.onnx"
input_shape = (1, 1, 28, 28)
device = torch.device("cpu")
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
assert calibration_config is not None
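A short sketch of how the exported artifacts could be inspected after this test; the file names reuse the ones above, and the onnx package is only needed for the last two lines:

import torch
import onnx

state_dict = torch.load("test_model.pth")         # quantized weights
calibration = torch.load("test_calibration.pth")  # dict returned by export_model
onnx_model = onnx.load("test_model.onnx")
onnx.checker.check_model(onnx_model)               # structural sanity check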
def test_torch_pruner_validation(self):
# test bad configuration
pruner_classes = [torch_pruner.__dict__[x] for x in \