Unverified commit 3fc79c74 authored by lin bin, committed by GitHub

[Quantization] support load_calibration_config (#4163)

parent 9a9cb3d9
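The diff below adds a load_calibration_config API to the quantizers so that calibration parameters exported by one quantizer can warm-start another, or resume the same one. A rough usage sketch follows; it is not part of the commit, and the import path and the TorchModel placeholder are assumptions based on the test added at the bottom of this diff.

import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer  # assumed import path

config_list = [{
    'quant_types': ['weight', 'input', 'output'],
    'quant_bits': {'weight': 8, 'input': 8, 'output': 8},
    'op_names': ['fc1', 'fc2'],
}]

# First run: quantize and export calibration parameters
# (normally you would calibrate or train before exporting).
model = TorchModel().eval()  # placeholder model from the test suite
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
calibration_config = quantizer.export_model('model.pth', 'calibration.pth')

# Second run: warm-start a fresh quantizer from the exported config,
# then fine-tune or resume quantization-aware training as usual.
model2 = TorchModel().eval()
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.01, momentum=0.5)
quantizer2 = QAT_Quantizer(model2, config_list, optimizer2)
quantizer2.compress()
quantizer2.load_calibration_config(calibration_config)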
@@ -388,15 +388,18 @@ class QAT_Quantizer(Quantizer):
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
if "weight" in config.get("quant_types", []):
module.register_buffer('weight_bits', torch.zeros(1))
weight_bits = get_bits_length(config, 'weight')
layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
if "input" in config.get("quant_types", []):
module.register_buffer('input_bits', torch.zeros(1))
module.register_buffer('tracked_min_input', torch.zeros(1))
module.register_buffer('tracked_max_input', torch.zeros(1))
input_bits = get_bits_length(config, 'input')
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
layer.module.register_buffer('input_bits', torch.Tensor([int(input_bits)]))
if "output" in config.get("quant_types", []):
module.register_buffer('output_bits', torch.zeros(1))
module.register_buffer('tracked_min_output', torch.zeros(1))
module.register_buffer('tracked_max_output', torch.zeros(1))
output_bits = get_bits_length(config, 'output')
layer.module.register_buffer('output_bits', torch.Tensor([int(output_bits)]))
layer.module.register_buffer('tracked_min_output', torch.zeros(1))
layer.module.register_buffer('tracked_max_output', torch.zeros(1))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
@@ -484,7 +487,7 @@ class QAT_Quantizer(Quantizer):
config = wrapper.config
module = wrapper.module
weight = module.weight
weight_bits = get_bits_length(config, 'weight')
weight_bits = int(module.weight_bits)
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
@@ -501,20 +504,13 @@ class QAT_Quantizer(Quantizer):
module.zero_point.copy_(zero_point)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bits = torch.Tensor([weight_bits])
# Weight cannot be modified in place, so when torch.nn.DataParallel is used this
# update is lost after each forward pass. However, the update does take effect on
# each replicated module during the forward pass, so the quantized weight is still
# used correctly.
wrapper.module.weight = weight
return weight
def quantize_input(self, inputs, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input_bits = get_bits_length(config, 'input')
module.input_bit = torch.tensor([input_bits])
input_bits = int(module.input_bits)
quant_start_step = config.get('quant_start_step', 0)
assert input_bits >= 1, "quant bits length should be at least 1"
@@ -544,8 +540,7 @@ class QAT_Quantizer(Quantizer):
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
module.output_bits = torch.Tensor([output_bits])
output_bits = int(module.output_bits)
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
@@ -574,6 +569,25 @@ class QAT_Quantizer(Quantizer):
out = self._dequantize(module, out)
return out
def load_calibration_config(self, calibration_config):
modules_to_compress = self.get_modules_to_compress()
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if name not in calibration_config:
if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
logger.warning(f"Can not find module {name}'s parameter in input config.")
continue
if hasattr(module, 'weight_bits'):
assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
if hasattr(module, 'input_bits'):
assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
module.tracked_min_input.data = torch.Tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.Tensor([calibration_config[name]['tracked_max_input']])
if hasattr(module, 'output_bits'):
assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
module.tracked_min_output.data = torch.Tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.Tensor([calibration_config[name]['tracked_max_output']])
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)
@@ -620,8 +634,8 @@ class QAT_Quantizer(Quantizer):
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if hasattr(module, 'input_bit'):
calibration_config[name]['input_bits'] = int(module.input_bit)
if hasattr(module, 'input_bits'):
calibration_config[name]['input_bits'] = int(module.input_bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
@@ -655,7 +669,8 @@ class DoReFaQuantizer(Quantizer):
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bits', torch.zeros(1))
weight_bits = get_bits_length(config, 'weight')
layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
@@ -690,13 +705,12 @@ class DoReFaQuantizer(Quantizer):
def quantize_weight(self, wrapper, **kwargs):
weight = wrapper.module.weight
weight_bits = get_bits_length(wrapper.config, 'weight')
weight_bits = int(wrapper.module.weight_bits)
weight = weight.tanh()
weight = weight / (2 * weight.abs().max()) + 0.5
weight = self.quantize(weight, weight_bits)
weight = 2 * weight - 1
wrapper.module.weight = weight
wrapper.module.weight_bits = torch.Tensor([weight_bits])
# wrapper.module.weight.data = weight
return weight
@@ -764,7 +778,8 @@ class BNNQuantizer(Quantizer):
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bits', torch.zeros(1))
weight_bits = get_bits_length(config, 'weight')
layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
self.bound_model.to(device)
def _del_simulated_attr(self, module):
@@ -890,10 +905,10 @@ class LsqQuantizer(Quantizer):
if "weight" in config.get("quant_types", []):
layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
# TODO: support per-channel quantization for weight, since TensorRT uses it for conv weights
q_bits = get_bits_length(config, "weight")
layer.module.register_buffer('weight_bits', torch.Tensor([q_bits]))
qmax = 2 ** (q_bits - 1) - 1
qmin = -2 ** (q_bits - 1)
weight_bits = get_bits_length(config, "weight")
layer.module.register_buffer('weight_bits', torch.Tensor([weight_bits]))
qmax = 2 ** (weight_bits - 1) - 1
qmin = -2 ** (weight_bits - 1)
init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
layer.module.weight_qmax = qmax
@@ -904,10 +919,10 @@ class LsqQuantizer(Quantizer):
if "output" in config.get("quant_types", []):
# scale of output will be initialized using the first batch data
layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bits = get_bits_length(config, "output")
layer.module.register_buffer('output_bits', torch.Tensor([q_bits]))
qmax = 2 ** (q_bits - 1) - 1
qmin = -2 ** (q_bits - 1)
output_bits = get_bits_length(config, "output")
layer.module.register_buffer('output_bits', torch.Tensor([output_bits]))
qmax = 2 ** (output_bits - 1) - 1
qmin = -2 ** (output_bits - 1)
layer.module.output_qmax = qmax
layer.module.output_qmin = qmin
@@ -916,10 +931,10 @@ class LsqQuantizer(Quantizer):
if "input" in config.get("quant_types", []):
# scale of input will be initialized using the first batch data
layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bits = get_bits_length(config, "input")
layer.module.register_buffer('input_bits', torch.Tensor([q_bits]))
qmax = 2 ** (q_bits - 1) - 1
qmin = -2 ** (q_bits - 1)
input_bits = get_bits_length(config, "input")
layer.module.register_buffer('input_bits', torch.Tensor([input_bits]))
qmax = 2 ** (input_bits - 1) - 1
qmin = -2 ** (input_bits - 1)
layer.module.input_qmax = qmax
layer.module.input_qmin = qmin
@@ -993,6 +1008,24 @@ class LsqQuantizer(Quantizer):
inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
return inputs
def load_calibration_config(self, calibration_config):
modules_to_compress = self.get_modules_to_compress()
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if name not in calibration_config:
if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
logger.warning(f"Can not find module {name}'s parameter in input config.")
continue
if hasattr(module, 'weight_bits'):
assert calibration_config[name]['weight_bits'] == int(module.weight_bits), f"weight bits of module {name} fail to match"
if hasattr(module, 'input_bits'):
assert calibration_config[name]['input_bits'] == int(module.input_bits), f"input bits of module {name} fail to match"
module.input_scale.data = torch.Tensor([float(calibration_config[name]['tracked_max_input'] / module.input_qmax)])
if hasattr(module, 'output_bits'):
assert calibration_config[name]['output_bits'] == int(module.output_bits), f"output bits of module {name} fail to match"
module.output_scale.data = torch.Tensor([float(calibration_config[name]['tracked_max_output'] / module.output_qmax)])
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)
@@ -803,6 +803,24 @@ class Quantizer(Compressor):
"""
raise NotImplementedError('Quantizer must overload export_model()')
def load_calibration_config(self, calibration_config):
"""
This function helps a quantizer set its quantization parameters by
loading them from a calibration_config exported by another quantizer
or by itself. Its main use is to give a quantization-aware training
quantizer appropriate initial parameters so that training is more
stable and converges faster. It also lets a quantizer resume a
quantized model by loading parameters from the config.
Parameters
----------
calibration_config : dict
A dict that stores quantization parameters. A quantizer can export its
own calibration config, e.g.
calibration_config = quantizer.export_model(model_path, calibration_path)
"""
raise NotImplementedError('Quantizer must overload load_calibration_config()')
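For illustration only (this is not part of the commit), a calibration_config is keyed by layer name, and each entry's keys mirror what QAT_Quantizer.export_model writes earlier in this diff; the numeric values below are invented.

# Hypothetical calibration_config consumed by load_calibration_config.
# Keys follow QAT_Quantizer.export_model; values are made-up examples.
calibration_config = {
    'conv1': {
        'weight_bits': 8,
        'input_bits': 8,
        'tracked_min_input': -0.43,
        'tracked_max_input': 2.81,
    },
    'fc1': {
        'weight_bits': 8,
        'input_bits': 8,
        'tracked_min_input': 0.0,
        'tracked_max_input': 5.97,
        'output_bits': 8,
        'tracked_min_output': -1.24,
        'tracked_max_output': 3.56,
    },
}

Note that LsqQuantizer consumes the same tracked_max_* entries differently, deriving its scales as tracked_max_input / input_qmax and tracked_max_output / output_qmax, as shown in its load_calibration_config above.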
def find_conv_bn_patterns(self, model, dummy_input):
"""
Find all Conv-BN patterns, used for batch normalization folding
@@ -445,6 +445,35 @@ class CompressorTestCase(TestCase):
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
assert calibration_config is not None
def test_quantizer_load_calibration_config(self):
configure_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1', 'conv2']
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1', 'fc2'],
}]
quantize_algorithm_set = [torch_quantizer.ObserverQuantizer, torch_quantizer.QAT_Quantizer, torch_quantizer.LsqQuantizer]
calibration_config = None
for quantize_algorithm in quantize_algorithm_set:
model = TorchModel().eval()
model.relu = torch.nn.ReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = quantize_algorithm(model, configure_list, optimizer)
quantizer.compress()
if calibration_config is not None:
quantizer.load_calibration_config(calibration_config)
model_path = "test_model.pth"
calibration_path = "test_calibration.pth"
onnx_path = "test_model.onnx"
input_shape = (1, 1, 28, 28)
device = torch.device("cpu")
calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
def test_torch_pruner_validation(self):
# test bad configuration
pruner_classes = [torch_pruner.__dict__[x] for x in \