"docs/source/compression/overview_zh.rst" did not exist on "88ffe9089d39695ffeba095b8b598946597cefdf"
Unverified commit 86335921, authored by lin bin, committed by GitHub

[Model Compression Quantization] Unify variable name (#3990)

parent e5c3ac63
@@ -58,10 +58,10 @@ def post_training_quantization_example(train_loader, test_loader, device):
     model = NaiveModel()
     config = {
-        'conv1':{'weight_bit':8, 'activation_bit':8},
-        'conv2':{'weight_bit':32, 'activation_bit':32},
-        'fc1':{'weight_bit':16, 'activation_bit':16},
-        'fc2':{'weight_bit':8, 'activation_bit':8}
+        'conv1':{'weight_bits':8, 'output_bits':8},
+        'conv2':{'weight_bits':32, 'output_bits':32},
+        'fc1':{'weight_bits':16, 'output_bits':16},
+        'fc2':{'weight_bits':8, 'output_bits':8}
     }
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -102,8 +102,10 @@ def quantization_aware_training_example(train_loader, test_loader, device):
     ]
     # finetune the model by using QAT
+    # enable batchnorm folding mode
+    dummy_input = torch.randn(1, 1, 28, 28)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-    quantizer = QAT_Quantizer(model, configure_list, optimizer)
+    quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
     quantizer.compress()
     model.to(device)
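The hunks above rename the per-layer bit keys in the post-training example config and thread a dummy input through QAT_Quantizer so batch-norm folding can be enabled. A hedged, self-contained sketch of the updated QAT calling convention (NaiveModel, the MNIST input shape, and the import path are assumptions based on the NNI 2.x examples, not part of this diff):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    # NaiveModel is the small MNIST network used by the NNI quantization examples;
    # the import path above matches NNI 2.x and may differ in other releases.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = NaiveModel().to(device)

    configure_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear']
    }]

    # passing a dummy input is what enables batch-norm folding inside QAT_Quantizer
    dummy_input = torch.randn(1, 1, 28, 28)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()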
@@ -124,7 +124,7 @@ class QATGrad(QuantGrad):
 class ObserverQuantizer(Quantizer):
-    """This quantizer uses observers to record weight/activation statistics to get quantization information.
+    """This quantizer uses observers to record weight/output statistics to get quantization information.
     The whole process can be divided into three steps:
     1. It will register observers to the place where quantization would happen (just like registering hooks).
     2. The observers would record tensors' statistics during calibration.
@@ -140,7 +140,7 @@ class ObserverQuantizer(Quantizer):
        # TODO:
        # 1. support dtype and qscheme customization through config_list. Current settings:
        #    weight observer : per_tensor_symmetric, qint8
-       #    activation observer : per_tensor_affine, quint8, reduce_range=True
+       #    output observer : per_tensor_affine, quint8, reduce_range=True
        # 2. add more kinds of observers, such as Kullback-Leibler divergence.
        # 3. add batch normalization folding
        assert not model.training, "Currently the observer quantizer only works in evaluation mode."
@@ -148,8 +148,8 @@ class ObserverQuantizer(Quantizer):
        self.device = next(model.parameters()).device
        modules_to_compress = self.get_modules_to_compress()
        all_observers = defaultdict(dict)
-       weight_q_min, weight_q_max = -127, 127
-       activation_q_min, activation_q_max = 0, 127  # reduce_range is set to True
+       weight_qmin, weight_qmax = -127, 127
+       output_qmin, output_qmax = 0, 127  # reduce_range is set to True
        self.compressed = False
        for layer, config in modules_to_compress:
@@ -157,16 +157,16 @@ class ObserverQuantizer(Quantizer):
            module = layer.module
            if "weight" in config.get("quant_types", []):
                all_observers[layer_name]["weight"] = default_weight_observer()
-               setattr(module, "weight_qmax", weight_q_max)
-               setattr(module, "weight_qmin", weight_q_min)
+               setattr(module, "weight_qmax", weight_qmax)
+               setattr(module, "weight_qmin", weight_qmin)
            if "input" in config.get("quant_types", []):
                all_observers[layer_name]["input"] = default_histogram_observer()
-               setattr(module, "input_qmax", activation_q_max)
-               setattr(module, "input_qmin", activation_q_min)
+               setattr(module, "input_qmax", output_qmax)
+               setattr(module, "input_qmin", output_qmin)
            if "output" in config.get("quant_types", []):
                all_observers[layer_name]["output"] = default_histogram_observer()
-               setattr(module, "output_qmax", activation_q_max)
-               setattr(module, "output_qmin", activation_q_min)
+               setattr(module, "output_qmax", output_qmax)
+               setattr(module, "output_qmin", output_qmin)
        self.all_observers = all_observers
        self.bound_model.to(self.device)
@@ -306,29 +306,29 @@ class ObserverQuantizer(Quantizer):
            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                calibration_config[name] = {}
                if hasattr(module, 'weight_scale'):
-                   calibration_config[name]['weight_bit'] = 8
+                   calibration_config[name]['weight_bits'] = 8
                    val = float(module.weight_scale * module.weight_qmax)
                    calibration_config[name]['tracked_max_weight'] = val
                    calibration_config[name]['tracked_min_weight'] = -val
-                   calibration_config[name]['tracked_weight_qmin'] = -127
-                   calibration_config[name]['tracked_weight_qmax'] = 127
+                   calibration_config[name]['tracked_qmin_weight'] = -127
+                   calibration_config[name]['tracked_qmax_weight'] = 127
                # refactor these magic numbers when customizations of dtype and qscheme are ready.
                if hasattr(module, 'input_scale'):
-                   calibration_config[name]['input_bit'] = 8
+                   calibration_config[name]['input_bits'] = 8
                    max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                    min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                    calibration_config[name]['tracked_min_input'] = min_input
                    calibration_config[name]['tracked_max_input'] = max_input
-                   calibration_config[name]['tracked_input_qmin'] = 0
-                   calibration_config[name]['tracked_input_qmax'] = 127
+                   calibration_config[name]['tracked_qmin_input'] = 0
+                   calibration_config[name]['tracked_qmax_input'] = 127
                if hasattr(module, 'output_scale'):
-                   calibration_config[name]['activation_bit'] = 8
+                   calibration_config[name]['output_bits'] = 8
                    max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point))
                    min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point))
-                   calibration_config[name]['tracked_min_activation'] = min_input
-                   calibration_config[name]['tracked_max_activation'] = max_input
-                   calibration_config[name]['tracked_activation_qmin'] = 0
-                   calibration_config[name]['tracked_activation_qmax'] = 127
+                   calibration_config[name]['tracked_min_output'] = min_input
+                   calibration_config[name]['tracked_max_output'] = max_input
+                   calibration_config[name]['tracked_qmin_output'] = 0
+                   calibration_config[name]['tracked_qmax_output'] = 127
            self._del_simulated_attr(module)
        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
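For reference, after this rename one entry of the calibration_config built by ObserverQuantizer's export path has roughly the following shape (the layer name and tracked ranges are illustrative values, not output from a real run):

    # illustrative shape of one exported entry; the numeric ranges are made up
    calibration_config = {
        'conv1': {
            'weight_bits': 8,
            'tracked_max_weight': 0.42, 'tracked_min_weight': -0.42,
            'tracked_qmin_weight': -127, 'tracked_qmax_weight': 127,
            'input_bits': 8,
            'tracked_min_input': -1.0, 'tracked_max_input': 1.0,
            'tracked_qmin_input': 0, 'tracked_qmax_input': 127,
            'output_bits': 8,
            'tracked_min_output': -2.3, 'tracked_max_output': 2.3,
            'tracked_qmin_output': 0, 'tracked_qmax_output': 127,
        }
    }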
@@ -354,7 +354,7 @@ class QAT_Quantizer(Quantizer):
    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
    """
-   def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+   def __init__(self, model, config_list, optimizer, dummy_input=None):
        """
        Parameters
        ----------
@@ -370,7 +370,7 @@ class QAT_Quantizer(Quantizer):
                when the type is int, all quantization types share same bits length
            - quant_start_step : int
                disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
-               state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
+               state where output quantization ranges do not exclude a significant fraction of values, default value is 0
            - op_types : list of string
                types of nn.module you want to apply quantization, eg. 'Conv2d'
        - dummy_input : tuple of tensor
@@ -379,6 +379,7 @@ class QAT_Quantizer(Quantizer):
            given, the batch normalization folding would be disabled.
        """
+       assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer, dummy_input)
        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
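The docstring above describes the per-layer config schema. A minimal config_list that exercises those fields might look like this (the op type, bit widths, and start step are arbitrary illustrative choices):

    configure_list = [{
        'quant_types': ['weight', 'output'],
        # an int applies to every quant type; a dict sets per-type bit widths
        'quant_bits': {'weight': 8, 'output': 8},
        # keep quantization disabled for the first 10 steps so tracked ranges stabilize
        'quant_start_step': 10,
        'op_types': ['Conv2d']
    }]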
@@ -389,22 +390,22 @@ class QAT_Quantizer(Quantizer):
                layer.module.register_buffer("scale", torch.Tensor([1.0]))
                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
            if "weight" in config.get("quant_types", []):
-               layer.module.register_buffer('weight_bit', torch.zeros(1))
+               layer.module.register_buffer('weight_bits', torch.zeros(1))
                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
            if "output" in config.get("quant_types", []):
-               layer.module.register_buffer('activation_bit', torch.zeros(1))
-               layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
-               layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
+               layer.module.register_buffer('output_bits', torch.zeros(1))
+               layer.module.register_buffer('tracked_min_output', torch.zeros(1))
+               layer.module.register_buffer('tracked_max_output', torch.zeros(1))
        self.bound_model.to(device)

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
-       del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
-                        'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
-                        'activation_bit', 'BN_FOLD_TAG']
+       del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
+                        'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
+                        'output_bits', 'BN_FOLD_TAG']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
@@ -506,7 +507,7 @@ class QAT_Quantizer(Quantizer):
        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
        weight = self._quantize(weight_bits, module, weight)
        weight = self._dequantize(module, weight)
-       module.weight_bit = torch.Tensor([weight_bits])
+       module.weight_bits = torch.Tensor([weight_bits])
        wrapper.module.weight = weight
        return weight
@@ -514,23 +515,23 @@ class QAT_Quantizer(Quantizer):
        config = wrapper.config
        module = wrapper.module
        output_bits = get_bits_length(config, 'output')
-       module.activation_bit = torch.Tensor([output_bits])
+       module.output_bits = torch.Tensor([output_bits])
        quant_start_step = config.get('quant_start_step', 0)
        assert output_bits >= 1, "quant bits length should be at least 1"

        if quant_start_step > self.bound_model.steps:
-           module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
+           module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
            return output

        # we dont update output quantization parameters in evaluation stage
        if wrapper.training:
            current_min, current_max = torch.min(output), torch.max(output)
-           module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
-                                                      module.ema_decay)
-           module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
-                                                      module.ema_decay)
+           module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
+                                                  module.ema_decay)
+           module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
+                                                  module.ema_decay)
            module.scale, module.zero_point = update_quantization_param(
-               output_bits, module.tracked_min_activation, module.tracked_max_activation)
+               output_bits, module.tracked_min_output, module.tracked_max_output)
        out = self._quantize(output_bits, module, output)
        out = self._dequantize(module, out)
        return out
@@ -562,10 +563,10 @@ class QAT_Quantizer(Quantizer):
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
-           if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
+           if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
                calibration_config[name] = {}
-           if hasattr(module, 'weight_bit'):
-               calibration_config[name]['weight_bit'] = int(module.weight_bit)
+           if hasattr(module, 'weight_bits'):
+               calibration_config[name]['weight_bits'] = int(module.weight_bits)
                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
@@ -585,10 +586,10 @@ class QAT_Quantizer(Quantizer):
                else:
                    setattr(module, 'bias', None)
-           if hasattr(module, 'activation_bit'):
-               calibration_config[name]['activation_bit'] = int(module.activation_bit)
-               calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
-               calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
+           if hasattr(module, 'output_bits'):
+               calibration_config[name]['output_bits'] = int(module.output_bits)
+               calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
+               calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
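The QAT export path now emits the renamed keys as well. A hedged sketch of exporting and inspecting the result (the export_model call follows the usual NNI quantizer API with a return value of the calibration dict, and the file paths are arbitrary placeholders):

    # export the quantized model plus the calibration data consumed by the
    # TensorRT speedup path; paths are arbitrary
    calibration_config = quantizer.export_model(
        model_path='qat_model.pth', calibration_path='qat_calibration.pth')

    for layer_name, cfg in calibration_config.items():
        # each entry now uses weight_bits / output_bits and tracked_*_output keys
        print(layer_name, cfg.get('weight_bits'), cfg.get('output_bits'),
              cfg.get('tracked_min_output'), cfg.get('tracked_max_output'))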
@@ -642,20 +643,21 @@ class DoReFaQuantizer(Quantizer):
    (https://arxiv.org/abs/1606.06160)
    """
-   def __init__(self, model, config_list, optimizer=None):
+   def __init__(self, model, config_list, optimizer):
+       assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer)
        device = next(model.parameters()).device
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
-               layer.module.register_buffer('weight_bit', torch.zeros(1))
+               layer.module.register_buffer('weight_bits', torch.zeros(1))
        self.bound_model.to(device)

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
-       del_attr_list = ['old_weight', 'weight_bit']
+       del_attr_list = ['old_weight', 'weight_bits']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
@@ -689,7 +691,7 @@ class DoReFaQuantizer(Quantizer):
        weight = self.quantize(weight, weight_bits)
        weight = 2 * weight - 1
        wrapper.module.weight = weight
-       wrapper.module.weight_bit = torch.Tensor([weight_bits])
+       wrapper.module.weight_bits = torch.Tensor([weight_bits])
        # wrapper.module.weight.data = weight
        return weight
@@ -725,9 +727,9 @@ class DoReFaQuantizer(Quantizer):
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
-           if hasattr(module, 'weight_bit'):
+           if hasattr(module, 'weight_bits'):
                calibration_config[name] = {}
-               calibration_config[name]['weight_bit'] = int(module.weight_bit)
+               calibration_config[name]['weight_bits'] = int(module.weight_bits)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
@@ -745,25 +747,26 @@ class ClipGrad(QuantGrad):
 class BNNQuantizer(Quantizer):
    """Binarized Neural Networks, as defined in:
-   Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
+   Binarized Neural Networks: Training Deep Neural Networks with Weights and Outputs Constrained to +1 or -1
    (https://arxiv.org/abs/1602.02830)
    """
-   def __init__(self, model, config_list, optimizer=None):
+   def __init__(self, model, config_list, optimizer):
+       assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer)
        device = next(model.parameters()).device
        self.quant_grad = ClipGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
-               layer.module.register_buffer('weight_bit', torch.zeros(1))
+               layer.module.register_buffer('weight_bits', torch.zeros(1))
        self.bound_model.to(device)

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
-       del_attr_list = ['old_weight', 'weight_bit']
+       del_attr_list = ['old_weight', 'weight_bits']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
@@ -796,7 +799,7 @@ class BNNQuantizer(Quantizer):
        # remove zeros
        weight[weight == 0] = 1
        wrapper.module.weight = weight
-       wrapper.module.weight_bit = torch.Tensor([1.0])
+       wrapper.module.weight_bits = torch.Tensor([1.0])
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
@@ -832,9 +835,9 @@ class BNNQuantizer(Quantizer):
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
-           if hasattr(module, 'weight_bit'):
+           if hasattr(module, 'weight_bits'):
                calibration_config[name] = {}
-               calibration_config[name]['weight_bit'] = int(module.weight_bit)
+               calibration_config[name]['weight_bits'] = int(module.weight_bits)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
@@ -848,7 +851,7 @@ class LsqQuantizer(Quantizer):
    https://arxiv.org/pdf/1902.08153.pdf
    """
-   def __init__(self, model, config_list, optimizer=None):
+   def __init__(self, model, config_list, optimizer):
        """
        Parameters
        ----------
@@ -864,10 +867,11 @@ class LsqQuantizer(Quantizer):
                when the type is int, all quantization types share same bits length
            - quant_start_step : int
                disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
-               state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
+               state where output quantization ranges do not exclude a significant fraction of values, default value is 0
            - op_types : list of string
                types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
+       assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer)
        device = next(model.parameters()).device
        self.quant_grad = QuantForward()
@@ -877,10 +881,10 @@ class LsqQuantizer(Quantizer):
            if "weight" in config.get("quant_types", []):
                layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                # todo: support per-channel quantization for weight since TensorRT use it for conv weight
-               q_bit = get_bits_length(config, "weight")
-               layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
-               qmax = 2 ** (q_bit - 1) - 1
-               qmin = -2 ** (q_bit - 1)
+               q_bits = get_bits_length(config, "weight")
+               layer.module.register_buffer('weight_bits', torch.Tensor([q_bits]))
+               qmax = 2 ** (q_bits - 1) - 1
+               qmin = -2 ** (q_bits - 1)
                init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
                layer.module.weight_qmax = qmax
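As a quick sanity check of the LSQ expressions above for 8-bit weights: qmax = 2**(8-1) - 1 = 127, qmin = -2**(8-1) = -128, and the initial scale is 2·mean(|w|)/sqrt(qmax), following the LSQ initialization quoted in the code. A tiny standalone check (the weight tensor here is a random placeholder, not a real layer):

    import torch

    q_bits = 8
    qmax = 2 ** (q_bits - 1) - 1      # 127
    qmin = -2 ** (q_bits - 1)         # -128

    # illustrative weight tensor; in the quantizer this is layer.module.weight.data
    weight = torch.randn(16, 8)
    init_weight_scale = weight.detach().abs().mean() * 2 / (qmax ** 0.5)
    print(qmin, qmax, float(init_weight_scale))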
@@ -889,12 +893,12 @@ class LsqQuantizer(Quantizer):
                self.optimizer.add_param_group({"params": layer.module.weight_scale})

            if "output" in config.get("quant_types", []):
-               # scale of activation will be initialized using the first batch data
+               # scale of output will be initialized using the first batch data
                layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-               q_bit = get_bits_length(config, "output")
-               layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
-               qmax = 2 ** (q_bit - 1) - 1
-               qmin = -2 ** (q_bit - 1)
+               q_bits = get_bits_length(config, "output")
+               layer.module.register_buffer('output_bits', torch.Tensor([q_bits]))
+               qmax = 2 ** (q_bits - 1) - 1
+               qmin = -2 ** (q_bits - 1)
                layer.module.output_qmax = qmax
                layer.module.output_qmin = qmin
@@ -903,10 +907,10 @@ class LsqQuantizer(Quantizer):
            if "input" in config.get("quant_types", []):
                # scale of input will be initialized using the first batch data
                layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
-               q_bit = get_bits_length(config, "input")
-               layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
-               qmax = 2 ** (q_bit - 1) - 1
-               qmin = -2 ** (q_bit - 1)
+               q_bits = get_bits_length(config, "input")
+               layer.module.register_buffer('input_bits', torch.Tensor([q_bits]))
+               qmax = 2 ** (q_bits - 1) - 1
+               qmin = -2 ** (q_bits - 1)
                layer.module.input_qmax = qmax
                layer.module.input_qmin = qmin
@@ -1011,18 +1015,18 @@ class LsqQuantizer(Quantizer):
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
-           if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
+           if hasattr(module, 'input_bits') or hasattr(module, 'output_bits'):
                calibration_config[name] = {}
-           if hasattr(module, 'weight_bit'):
-               calibration_config[name]['weight_bit'] = int(module.weight_bit)
+           if hasattr(module, 'weight_bits'):
+               calibration_config[name]['weight_bits'] = int(module.weight_bits)
                abs_max_input = float(module.input_scale * module.input_qmax)
                calibration_config[name]['tracked_min_input'] = -abs_max_input
                calibration_config[name]['tracked_max_input'] = abs_max_input
-           if hasattr(module, 'output_bit'):
-               calibration_config[name]['activation_bit'] = int(module.output_bit)
+           if hasattr(module, 'output_bits'):
+               calibration_config[name]['output_bits'] = int(module.output_bits)
                abs_max_output = float(module.output_scale * module.output_qmax)
-               calibration_config[name]['tracked_min_activation'] = -abs_max_output
-               calibration_config[name]['tracked_max_activation'] = abs_max_output
+               calibration_config[name]['tracked_min_output'] = -abs_max_output
+               calibration_config[name]['tracked_max_output'] = abs_max_output
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
@@ -1034,8 +1038,8 @@ class LsqQuantizer(Quantizer):
        """
        delete redundant parameters in quantize module
        """
-       del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
-                        'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale', 'weight_bit', 'output_bit', 'input_bit']
+       del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_output', \
+                        'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale', 'weight_bits', 'output_bits', 'input_bits']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)
@@ -834,7 +834,7 @@ class QuantGrad(torch.autograd.Function):
    @classmethod
    def get_bits_length(cls, config, quant_type):
        """
-       Get bit for quantize config
+       Get bits for quantize config

        Parameters
        ----------
        config : Dict
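For context, get_bits_length resolves the per-type bit width out of a layer config whose quant_bits entry may be either a single int or a per-type dict, as described in the quantizer docstrings above. A hedged sketch of that lookup, not the exact NNI implementation:

    def get_bits_length(config, quant_type):
        # config['quant_bits'] is either an int shared by all quant types
        # or a dict such as {'weight': 8, 'output': 8}
        quant_bits = config['quant_bits']
        if isinstance(quant_bits, int):
            return quant_bits
        return quant_bits.get(quant_type)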
@@ -9,26 +9,26 @@ The main function of this page is to convert pytorch model to onnx model.
 Convertion from pytorch model to onnx model is primary so that a critical
 problem is caused that Layer name of pytorch model fail to convert to onnx
 layer name directly. To solve it, we wrap pytorch model in new wrapper which
-multiply bit number and input before computation of each op. Only in this
-way can onnx model get bit number of corresponded layer.
+multiply bits number and input before computation of each op. Only in this
+way can onnx model get bits number of corresponded layer.
 """

 class LayernameModuleWrapper(torch.nn.Module):
-    def __init__(self, module, module_bit) -> None:
+    def __init__(self, module, module_bits) -> None:
        """
        Parameters
        ----------
        module : torch.nn.Module
            Layer module of pytorch model
-       module_bit : int
-           Bit width setting for module
+       module_bits : int
+           Bits width setting for module
        """
        super().__init__()
        self.module = module
-       self.module_bit = module_bit
+       self.module_bits = module_bits

    def forward(self, inputs):
-       inputs = inputs*self.module_bit
+       inputs = inputs*self.module_bits
        inputs = self.module(inputs)
        return inputs
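To illustrate the trick described in that module docstring: each supported op is wrapped so that its configured bit width is multiplied into the input, which makes the bit width recoverable from the exported ONNX graph (and the extra multiply is stripped again by the unwrapper step). A small hedged sketch of wrapping one layer by hand:

    import torch

    conv = torch.nn.Conv2d(1, 4, 3)
    # hypothetical wrapping of a single layer with an 8-bit setting
    wrapped = LayernameModuleWrapper(conv, 8)

    x = torch.randn(1, 1, 28, 28)
    # forward() multiplies the input by module_bits before running the op,
    # so the ONNX export contains a Mul node carrying the bit width
    y = wrapped(x)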
@@ -93,14 +93,14 @@ def unwrapper(model_onnx, index2name, config):

 def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
    """
-   Convert torch model to onnx model and get layer bit config of onnx model.
+   Convert torch model to onnx model and get layer bits config of onnx model.

    Parameters
    ----------
    model : pytorch model
        The model to speed up by quantization
    config : dict
-       Config recording bit number and name of layers
+       Config recording bits number and name of layers
    input_shape : tuple
        The input shape of model, shall pass it to torch.onnx.export
    model_path : str
@@ -119,7 +119,7 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool
    support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d]
-   # Transfer bit number to onnx layer by using wrapper
+   # Transfer bits number to onnx layer by using wrapper
    index2name = {}
    name2index = {}
    if config is not None:
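A hedged example of calling this helper with the renamed config keys; the input shape, output path, and tensor names are placeholders, and calibration_config is the dict exported by a quantizer as shown earlier:

    # paths and tensor names are placeholders; calibration_config comes from
    # a quantizer's export_model (see the QAT export sketch above)
    torch_to_onnx(model, calibration_config, input_shape=(1, 1, 28, 28),
                  model_path='quantized_model.onnx',
                  input_names=['actual_input_1'], output_names=['output1'])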
@@ -31,18 +31,18 @@ Precision_Dict = {

 def valid_config(config=None):
    """
-   This function validates the bit setting configuration
+   This function validates the bits setting configuration
    """
    if config is None:
        return
-   support_bit = [8, 16, 32]
+   support_bits = [8, 16, 32]
    for name in config.keys():
-       if 'weight_bit' in config[name]:
-           w_bit = config[name]['weight_bit']
-           assert w_bit in support_bit, "weight bit should be 8, 16, 32"
-       if 'activation_bit' in config[name]:
-           a_bit = config[name]['activation_bit']
-           assert a_bit in support_bit, "activation bit should be 8, 16, 32"
+       if 'weight_bits' in config[name]:
+           w_bits = config[name]['weight_bits']
+           assert w_bits in support_bits, "weight bits should be 8, 16, 32"
+       if 'output_bits' in config[name]:
+           a_bits = config[name]['output_bits']
+           assert a_bits in support_bits, "output bits should be 8, 16, 32"

 def handle_gemm(network, layer_idx, config):
    """
@@ -55,26 +55,26 @@ def handle_gemm(network, layer_idx, config):
    layer_idx : int
        layer index of gemm
    config : dict
-       Config recording bit number and name of layers
+       Config recording bits number and name of layers
    """
    layer = network.get_layer(layer_idx)
    pre_layer = network.get_layer(layer_idx-1)
    next_layer = network.get_layer(layer_idx+1)
-   # if weight bit exists, set three layers' precision,
+   # if weight bits exists, set three layers' precision,
    # input tensor range and the first two layers' output type
-   if 'weight_bit' in config[layer.name]:
+   if 'weight_bits' in config[layer.name]:
        assert 'tracked_min_input' in config[layer.name]
        assert 'tracked_max_input' in config[layer.name]
-       w_bit = config[layer.name]['weight_bit']
+       w_bits = config[layer.name]['weight_bits']
        tracked_min_input = config[layer.name]['tracked_min_input']
        tracked_max_input = config[layer.name]['tracked_max_input']
        # set three layers the same precision
-       layer.precision = Precision_Dict[w_bit]
-       pre_layer.precision = Precision_Dict[w_bit]
-       next_layer.precision = Precision_Dict[w_bit]
+       layer.precision = Precision_Dict[w_bits]
+       pre_layer.precision = Precision_Dict[w_bits]
+       next_layer.precision = Precision_Dict[w_bits]
        # set the first two layers' output type
-       pre_layer.set_output_type(0, Precision_Dict[w_bit])
-       layer.set_output_type(0, Precision_Dict[w_bit])
+       pre_layer.set_output_type(0, Precision_Dict[w_bits])
+       layer.set_output_type(0, Precision_Dict[w_bits])
        pre_in_tensor = pre_layer.get_input(0)
        in_tensor = layer.get_input(0)
        next_in_tensor = next_layer.get_input(0)
@@ -83,20 +83,20 @@ def handle_gemm(network, layer_idx, config):
        in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
        next_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)

-   # if activation bit exists, set the last layer's output type output tensor range
-   if 'activation_bit' in config[layer.name]:
-       assert 'tracked_min_activation' in config[layer.name]
-       assert 'tracked_max_activation' in config[layer.name]
-       a_bit = config[layer.name]['activation_bit']
-       tracked_min_activation = config[layer.name]['tracked_min_activation']
-       tracked_max_activation = config[layer.name]['tracked_max_activation']
+   # if output bits exists, set the last layer's output type output tensor range
+   if 'output_bits' in config[layer.name]:
+       assert 'tracked_min_output' in config[layer.name]
+       assert 'tracked_max_output' in config[layer.name]
+       a_bits = config[layer.name]['output_bits']
+       tracked_min_output = config[layer.name]['tracked_min_output']
+       tracked_max_output = config[layer.name]['tracked_max_output']
        # set the last layer's output type
-       next_layer.set_output_type(0, Precision_Dict[a_bit])
+       next_layer.set_output_type(0, Precision_Dict[a_bits])
        next_out_tensor = next_layer.get_output(0)
        # set the last layer's output tensor range
-       next_out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
+       next_out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)

-def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
+def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=False, calib=None):
    """
    This function builds an engine from an onnx model with calibration process.
@@ -105,12 +105,12 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
    model_file : str
        The path of onnx model
    config : dict
-       Config recording bit number and name of layers
-   extra_layer_bit : int
-       Other layers which are not in config will be quantized to corresponding bit number
+       Config recording bits number and name of layers
+   extra_layer_bits : int
+       Other layers which are not in config will be quantized to corresponding bits number
    strict_datatype : bool
-       Whether constrain layer bit to the number given in config or not. If true, all the layer
-       will be set to given bit strictly. Otherwise, these layers will be set automatically by
+       Whether constrain layer bits to the number given in config or not. If true, all the layer
+       will be set to given bits strictly. Otherwise, these layers will be set automatically by
        tensorrt
    calib : numpy array
        The data using to calibrate quantization model
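The config consumed here is the calibration_config exported by the quantizers above, keyed by layer name and using the renamed fields. An illustrative entry (all numeric values are made up):

    # illustrative calibration config for the TensorRT speedup path
    config = {
        'fc1': {
            'weight_bits': 8,
            'tracked_min_input': -2.5, 'tracked_max_input': 2.5,
            'output_bits': 8,
            'tracked_min_output': -4.0, 'tracked_max_output': 4.0,
        }
    }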
@@ -135,14 +135,14 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
    else:
        builder.max_workspace_size = common.GiB(4)

-   if extra_layer_bit == 32 and config is None:
+   if extra_layer_bits == 32 and config is None:
        pass
-   elif extra_layer_bit == 16 and config is None:
+   elif extra_layer_bits == 16 and config is None:
        if trt_version == TRT8:
            trt_config.set_flag(trt.BuilderFlag.FP16)
        else:
            builder.fp16_mode = True
-   elif extra_layer_bit == 8 and config is None:
+   elif extra_layer_bits == 8 and config is None:
        # entire model in 8bit mode
        if trt_version == TRT8:
            trt_config.set_flag(trt.BuilderFlag.INT8)
@@ -180,15 +180,15 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
                break
            layer = network.get_layer(i)
            if layer.name in config:
-               w_bit = config[layer.name]['weight_bit']
-               a_bit = config[layer.name]['activation_bit']
-               layer.precision = Precision_Dict[w_bit]
-               layer.set_output_type(0, Precision_Dict[a_bit])
+               w_bits = config[layer.name]['weight_bits']
+               a_bits = config[layer.name]['output_bits']
+               layer.precision = Precision_Dict[w_bits]
+               layer.set_output_type(0, Precision_Dict[a_bits])
    else:
        # This implementation may be incorrect when output number > 1
        for i in range(network.num_layers):
            if config is None:
-               # no low bit layer need to be set, keep original model
+               # no low bits layer need to be set, keep original model
                break
            layer = network.get_layer(i)
            if layer.name not in config:
@@ -198,37 +198,37 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
                handle_gemm(network, i, config)
                continue

-           # If weight_bit exists in config, set layer precision and layer's input tensor dynamic range.
-           if 'weight_bit' in config[layer.name]:
+           # If weight_bits exists in config, set layer precision and layer's input tensor dynamic range.
+           if 'weight_bits' in config[layer.name]:
                assert 'tracked_min_input' in config[layer.name]
                assert 'tracked_max_input' in config[layer.name]
-               w_bit = config[layer.name]['weight_bit']
+               w_bits = config[layer.name]['weight_bits']
                tracked_min_input = config[layer.name]['tracked_min_input']
                tracked_max_input = config[layer.name]['tracked_max_input']
-               layer.precision = Precision_Dict[w_bit]
+               layer.precision = Precision_Dict[w_bits]
                in_tensor = layer.get_input(0)
                in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)

-           # If activation exists in config, set layer output type and layer's output tensor dynamic range.
-           if 'activation_bit' in config[layer.name]:
-               assert 'tracked_min_activation' in config[layer.name]
-               assert 'tracked_max_activation' in config[layer.name]
-               a_bit = config[layer.name]['activation_bit']
-               tracked_min_activation = config[layer.name]['tracked_min_activation']
-               tracked_max_activation = config[layer.name]['tracked_max_activation']
-               layer.set_output_type(0, Precision_Dict[a_bit])
+           # If output exists in config, set layer output type and layer's output tensor dynamic range.
+           if 'output_bits' in config[layer.name]:
+               assert 'tracked_min_output' in config[layer.name]
+               assert 'tracked_max_output' in config[layer.name]
+               a_bits = config[layer.name]['output_bits']
+               tracked_min_output = config[layer.name]['tracked_min_output']
+               tracked_max_output = config[layer.name]['tracked_max_output']
+               layer.set_output_type(0, Precision_Dict[a_bits])
                out_tensor = layer.get_output(0)
-               out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
+               out_tensor.dynamic_range = (tracked_min_output, tracked_max_output)

    # Build engine and do int8 calibration.
    if trt_version == TRT8:
        engine = builder.build_engine(network, trt_config)
    else:
-       engine.builder.build_cuda_engine(network)
+       engine = builder.build_cuda_engine(network)
    return engine

 class ModelSpeedupTensorRT(BaseModelSpeedup):
-   def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bit=32, strict_datatype=True,
+   def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1,
                 input_names=["actual_input_1"], output_names=["output1"]):
        """
@@ -239,14 +239,14 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
        input_shape : tuple
            The input shape of model, shall pass it to torch.onnx.export.
        config : dict
-           Config recording bit number and name of layers.
+           Config recording bits number and name of layers.
        onnx_path : str
            The path user want to store onnx model which is converted from pytorch model.
-       extra_layer_bit : int
-           Other layers which are not in config will be quantized to corresponding bit number.
+       extra_layer_bits : int
+           Other layers which are not in config will be quantized to corresponding bits number.
        strict_datatype : bool
-           Whether constrain layer bit to the number given in config or not. If true, all the layer
-           will be set to given bit strictly. Otherwise, these layers will be set automatically by
+           Whether constrain layer bits to the number given in config or not. If true, all the layer
+           will be set to given bits strictly. Otherwise, these layers will be set automatically by
            tensorrt.
        calibrate_type : tensorrt.tensorrt.CalibrationAlgoType
            The algorithm of calibrating. Please refer to https://docs.nvidia.com/deeplearning/
@@ -267,7 +267,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
        self.onnx_path = onnx_path
        self.input_shape = input_shape
        self.config = config
-       self.extra_layer_bit = extra_layer_bit
+       self.extra_layer_bits = extra_layer_bits
        self.strict_datatype = strict_datatype
        self.calibrate_type = calibrate_type
        self.calib_data_loader = calib_data_loader
@@ -327,7 +327,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
        calib = calibrator.Calibrator(calib_data, self.calibration_cache, self.batchsize, self.calibrate_type)
        # build inference engine with calibration
-       engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype, calib)
+       engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bits, self.strict_datatype, calib)
        return engine.create_execution_context()

    def _tensorrt_build_withoutcalib(self, onnx_path):
@@ -344,7 +344,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
        tensorrt.IExecutionContext
            Context for executing inference using an ICudaEngine
        """
-       engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype)
+       engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bits, self.strict_datatype)
        return engine.create_execution_context()

    def inference(self, test_data):
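Putting the renamed pieces together on the speedup side, construction now takes extra_layer_bits. A hedged usage sketch, with calibration_config exported as above; the import path follows the NNI 2.x layout, the input shape and batch size are placeholders, and the subsequent engine build plus engine.inference(test_data) calls follow the NNI quantization speedup docs rather than anything shown in this hunk:

    from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

    # construction is fully covered by this diff; later steps are only sketched
    engine = ModelSpeedupTensorRT(
        model, input_shape=(32, 1, 28, 28),
        config=calibration_config, extra_layer_bits=32, strict_datatype=True,
        batchsize=32)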
@@ -49,7 +49,8 @@ class CompressorTestCase(TestCase):
        }]
        model.relu = torch.nn.ReLU()
-       quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
+       optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+       quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
@@ -317,7 +318,9 @@ class CompressorTestCase(TestCase):
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
-       quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
+
+       optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+       quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
        quantizer.compress()
        # test quantize
@@ -350,14 +353,14 @@ class CompressorTestCase(TestCase):
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
-       assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps)
-       assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps)
+       assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
+       assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
-       assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps)
-       assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps)
+       assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
+       assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)

    def test_torch_quantizer_export(self):
        config_list_qat = [{
@@ -392,7 +395,8 @@ class CompressorTestCase(TestCase):
        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
-           quantizer = quantize_algorithm(model, config)
+           optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+           quantizer = quantize_algorithm(model, config, optimizer)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
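The expected values in the tracked_min_output / tracked_max_output assertions above come from the exponential moving average with decay 0.99 used by QAT_Quantizer (see the quantize_output hunk earlier): starting from 0, the first batch's ReLU max of 0.2 gives 0.99*0 + 0.01*0.2 = 0.002, and the second batch's max of 0.8 gives 0.99*0.002 + 0.01*0.8 = 0.00998. A standalone check, assuming update_ema(old, new, decay) = decay*old + (1-decay)*new, which is consistent with these test values:

    def ema(old, new, decay=0.99):
        # assumed form of update_ema, consistent with the test's expected values
        return decay * old + (1 - decay) * new

    tracked_max = ema(0.0, 0.2)          # max of relu([[-0.2, 0], [0.1, 0.2]]) -> 0.002
    tracked_max = ema(tracked_max, 0.8)  # max of relu([[0.2, 0.4], [0.6, 0.8]]) -> 0.00998
    print(tracked_max)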