Unverified Commit e76f196d authored by chenbohua3, committed by GitHub

add quantize_input to QAT quantizer (#4084)

parent 5fc73ba6
@@ -7,7 +7,7 @@ import sys
sys.path.append('../models')
from mnist.naive import NaiveModel
def train(model, quantizer, device, train_loader, optimizer):
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
@@ -19,6 +19,7 @@ def train(model, quantizer, device, train_loader, optimizer):
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
@@ -47,30 +48,45 @@ def main():
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = NaiveModel()
'''you can change this to DoReFaQuantizer to use DoReFa quantization instead:
DoReFaQuantizer(configure_list).compress(model)
'''
# Two things should be kept in mind when setting this configure_list:
# 1. When deploying the model on a backend, some layers will be fused into one layer. For example, consecutive
# conv + bn + relu layers will be fused into one big layer. If we want to execute the big layer in quantization
# mode, we should tell the backend the quantization information of the input, weight, and output tensors of
# the big layer, which correspond to the conv's input, the conv's weight, and the relu's output.
# 2. The same tensor should be quantized only once. For example, if a tensor is the output of layer A and the input
# of layer B, you should configure either {'quant_types': ['output'], 'op_names': ['a']} or
# {'quant_types': ['input'], 'op_names': ['b']} in the configure_list, but not both (a short illustration
# follows the configure_list below).
configure_list = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # you can just use `int` here because all `quant_types` share the same bit length, see the config for `ReLU6` below.
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 1000,
'op_types':['ReLU6']
}]
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1', 'conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8, },
'op_names': ['relu1', 'relu2']
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc1'],
}, {
'quant_types': ['output', 'weight', 'input'],
'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
'op_names': ['fc2'],
}]
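As a short illustration of rule 2 from the comment above (a sketch, not part of the script; it assumes conv1's output feeds directly into relu1 in NaiveModel): the two entries below would quantize the same tensor twice, so at most one of them should appear in a configure_list.
# double quantization -- conv1's output and relu1's input are the same tensor
bad_configure_list = [
    {'quant_types': ['output'], 'quant_bits': {'output': 8}, 'op_names': ['conv1']},
    {'quant_types': ['input'], 'quant_bits': {'input': 8}, 'op_names': ['relu1']},
]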
model = NaiveModel().to(device)
dummy_input = torch.randn(1, 1, 28, 28).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer)
# To enable batch normalization folding in the training process, you should
# pass dummy_input to the QAT_Quantizer.
quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
quantizer.compress()
model.to(device)
for epoch in range(40):
print('# Epoch {} #'.format(epoch))
train(model, quantizer, device, train_loader, optimizer)
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
model_path = "mnist_model.pth"
......
@@ -380,8 +380,10 @@ class QAT_Quantizer(Quantizer):
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bits', torch.zeros(1))
if "input" in config.get("quant_types", []):
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
layer.module.register_buffer('input_bits', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('output_bits', torch.zeros(1))
layer.module.register_buffer('tracked_min_output', torch.zeros(1))
@@ -394,7 +396,7 @@ class QAT_Quantizer(Quantizer):
"""
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
'output_bits', 'BN_FOLD_TAG']
'output_bits', 'BN_FOLD_TAG', 'input_bits']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -409,8 +411,9 @@ class QAT_Quantizer(Quantizer):
List of configurations
"""
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('input'): And(int, lambda n: 0 < n < 32),
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
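For reference, a config entry like the one below now passes validation since 'input' is an accepted quant type (illustrative, not taken from this diff); bit widths must still lie strictly between 0 and 32, so e.g. 'quant_bits': 32 would be rejected.
{'quant_types': ['input', 'weight'],
 'quant_bits': {'input': 8, 'weight': 8},
 'op_types': ['Conv2d', 'Linear']}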
@@ -472,25 +475,17 @@ class QAT_Quantizer(Quantizer):
config = wrapper.config
module = wrapper.module
weight = module.weight
input = kwargs['input_tensor'] # pylint: disable=redefined-builtin
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
# we don't update weight in evaluation stage
if quant_start_step > self.bound_model.steps:
module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
return weight
if not wrapper.training:
return weight
current_min, current_max = torch.min(input), torch.max(input)
module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
module.ema_decay)
module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
module.ema_decay)
# quantize weight
rmin, rmax = torch.min(weight), torch.max(weight)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
@@ -500,6 +495,31 @@ class QAT_Quantizer(Quantizer):
wrapper.module.weight = weight
return weight
def quantize_input(self, inputs, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input_bits = get_bits_length(config, 'input')
module.input_bits = torch.Tensor([input_bits])
quant_start_step = config.get('quant_start_step', 0)
assert input_bits >= 1, "quant bits length should be at least 1"
if quant_start_step > self.bound_model.steps:
module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
return inputs
# we don't update input quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
module.ema_decay)
module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
input_bits, module.tracked_min_input, module.tracked_max_input)
inp = self._quantize(input_bits, module, inputs)
inp = self._dequantize(module, inp)
return inp
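For reference, the arithmetic performed by quantize_input can be sketched outside of NNI as follows. update_ema, update_quantization_param and _quantize/_dequantize are the quantizer's own helpers; the function names below are illustrative stand-ins and the zero-point rounding is simplified.
import torch

def ema(tracked, current, decay=0.99):
    # exponential moving average used to track the input range during training
    return decay * tracked + (1 - decay) * current

def affine_params(bits, rmin, rmax):
    # extend the range so zero is always representable, then derive scale / zero_point
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    qmin, qmax = 0, 2 ** bits - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = int(round(qmin - rmin / scale))
    return scale, zero_point

def fake_quantize(x, scale, zero_point, bits):
    # quantize to the integer grid, clamp, then dequantize back to float
    q = torch.clamp(torch.round(x / scale + zero_point), 0, 2 ** bits - 1)
    return (q - zero_point) * scale

# e.g. one training step on the tensor used in the unit test below
x = torch.tensor([[1., 4.], [2., 1.]])
tracked_max = ema(0.0, float(x.max()))          # 0.99 * 0 + 0.01 * 4 = 0.04
scale, zp = affine_params(8, 0.0, tracked_max)  # scale = 0.04 / 255, zero_point = 0
print(fake_quantize(x, scale, zp, 8))           # values clamp hard until the EMA warms up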
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
@@ -519,8 +539,9 @@ class QAT_Quantizer(Quantizer):
module.ema_decay)
module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
output_bits, module.tracked_min_output, module.tracked_max_output)
module.scale, module.zero_point = update_quantization_param(
output_bits, module.tracked_min_output, module.tracked_max_output)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
@@ -556,8 +577,6 @@ class QAT_Quantizer(Quantizer):
calibration_config[name] = {}
if hasattr(module, 'weight_bits'):
calibration_config[name]['weight_bits'] = int(module.weight_bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
# Recover weight/bias for batch normalization folding
actual_weight = getattr(module, 'old_weight', None)
@@ -573,6 +592,10 @@ class QAT_Quantizer(Quantizer):
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if hasattr(module, 'input_bits'):
calibration_config[name]['input_bits'] = int(module.input_bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
if hasattr(module, 'output_bits'):
calibration_config[name]['output_bits'] = int(module.output_bits)
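With input quantization enabled as above, an exported calibration_config entry for a layer that quantizes weight, input, and output (fc1 in the example script) would look roughly like the following; the values are purely illustrative.
calibration_config['fc1'] = {
    'weight_bits': 8,
    'input_bits': 8,
    'tracked_min_input': -0.5,   # EMA-tracked input range observed during training
    'tracked_max_input': 2.3,
    'output_bits': 8,
    # ...tracked output range entries follow, elided in this hunk
}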
......
@@ -308,7 +308,7 @@ class CompressorTestCase(TestCase):
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
'quant_types': ['weight'],
'quant_types': ['weight', 'input'],
'quant_bits': 8,
'op_types': ['Conv2d', 'Linear']
}, {
@@ -326,18 +326,24 @@ class CompressorTestCase(TestCase):
# test quantize
# range not including 0
eps = 1e-7
input = torch.tensor([[0, 4], [2, 1]]).float()
input = torch.tensor([[1, 4], [2, 1]])
weight = torch.tensor([[1, 2], [3, 5]]).float()
model.conv2.module.weight.data = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
model.conv2.module.weight = weight
quantizer.quantize_weight(model.conv2, input_tensor=input)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
quantizer.quantize_input(input, model.conv2)
self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
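The expected scales in the two quantize_input checks above follow from the EMA update with decay 0.99: the tracked maximum starts at 0, so it becomes 0.99 * 0 + 0.01 * 4 = 0.04 after the first call and 0.99 * 0.04 + 0.01 * 4 = 0.0796 after the second, while the tracked minimum (0.01, then about 0.02) is extended down to 0 when the quantization parameters are computed so that zero stays representable; hence scale is 0.04 / 255 and then 0.0796 / 255, with zero_point 0 in both cases.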
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
......