Commit 06a98372 authored by Cjkkkk, committed by chicm-ms

add new QAT_quantization (#1732)

parent a03570a0
......@@ -28,17 +28,23 @@ In [Quantization and Training of Neural Networks for Efficient Integer-Arithmeti
### Usage
You can quantize your model to 8 bits with the code below before your training code.
Tensorflow code
```python
from nni.compressors.tensorflow import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(tf.get_default_graph(), config_list)
quantizer.compress()
```
PyTorch code
```python
from nni.compression.torch import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
model = Mnist()
config_list = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # you can also just use an `int` here because all `quant_types` share the same bit length; see the config for `ReLU6` below.
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()
```
......@@ -46,9 +52,17 @@ quantizer.compress()
You can view the example for more information.
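As a rough sketch of how the quantizer fits into a training loop (assuming the `model` and `quantizer` built above, plus a `train_loader` and the usual `torch`/`F` imports from the example), `quantizer.step()` is called once per batch so that `quant_start_step` can take effect:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for data, target in train_loader:
    optimizer.zero_grad()
    loss = F.nll_loss(model(data), target)
    loss.backward()
    optimizer.step()
    quantizer.step()  # advance the quantizer's internal step counter after each batch
```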
#### User configuration for QAT Quantizer
* **q_bits:** This is to specify the number of bits to which operations are quantized
* **quant_types:** list of string
types of quantization you want to apply; currently supports 'weight', 'input' and 'output'
* **quant_bits:** int or dict of {str : int}
bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
When an int is given, all quantization types share the same bit length (see the sketch after this list).
* **quant_start_step:** int
disable quantization until the model has run for a certain number of steps. This allows the network to enter a more stable
state, where activation quantization ranges do not exclude a significant fraction of values. Default value is 0.
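For instance, a minimal `config_list` sketch contrasting the two forms of `quant_bits` (the op types and step count here are only illustrative):
```python
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': 8,                  # int form: all listed quant_types use 8 bits
    'op_types': ['Conv2d', 'Linear']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': 8},      # dict form: one bit length per quantization type
    'quant_start_step': 1000,         # keep full precision for the first 1000 steps
    'op_types': ['ReLU6']
}]
```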
### Note
Batch normalization folding is currently not supported.
***
## DoReFa Quantizer
......
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import QAT_Quantizer
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
def forward(self, x):
x = self.relu1(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = self.relu2(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(model, quantizer, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
quantizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('Loss: {} Accuracy: {}%\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
def main():
torch.manual_seed(0)
device = torch.device('cpu')
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = Mnist()
'''you can change this to DoReFaQuantizer to use DoReFa quantization instead:
DoReFaQuantizer(model, configure_list).compress()
'''
configure_list = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # you can also just use an `int` here because all `quant_types` share the same bit length; see the config for `ReLU6` below.
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, configure_list)
quantizer.compress()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
print('# Epoch {} #'.format(epoch))
train(model, quantizer, device, train_loader, optimizer)
test(model, device, test_loader)
if __name__ == '__main__':
main()
......@@ -22,6 +22,80 @@ class NaiveQuantizer(Quantizer):
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
def update_ema(biased_ema, value, decay, step):
"""
calculate the biased and unbiased statistics at each step using the exponential moving average method
Parameters
----------
biased_ema : float
previous stat value
value : float
current stat value
decay : float
the weight of previous stat value, larger means smoother curve
step : int
current step
Returns
-------
float, float
"""
biased_ema = biased_ema * decay + (1 - decay) * value
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction
return biased_ema, unbiased_ema
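# Worked example (illustrative): the tracked_* buffers below start at 0, so with
# decay = 0.99 at step = 1 the biased value is only 0.01 * value, while the
# bias-corrected value recovers `value` exactly:
#   biased_ema   = 0 * 0.99 + (1 - 0.99) * value = 0.01 * value
#   unbiased_ema = 0.01 * value / (1 - 0.99 ** 1) = value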
def update_quantization_param(bits, rmin, rmax):
"""
calculate the `zero_point` and `scale`.
Parameters
----------
bits : int
quantization bits length
rmin : float
min value of real value
rmax : float
max value of real value
Returns
-------
float, float
"""
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
rmin = min(rmin, 0)
rmax = max(rmax, 0)
# the min and max quantized values, as floating-point values
qmin = 0
qmax = (1 << bits) - 1
# First determine the scale.
scale = (rmax - rmin) / (qmax - qmin)
# Zero-point computation.
initial_zero_point = qmin - rmin / scale
# Now we need to nudge the zero point to be an integer
nudged_zero_point = 0
if initial_zero_point < qmin:
nudged_zero_point = qmin
elif initial_zero_point > qmax:
nudged_zero_point = qmax
else:
nudged_zero_point = torch.round(initial_zero_point)
return scale, nudged_zero_point
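# Worked example (illustrative values): with bits = 8, rmin = -1.0 and rmax = 3.0,
# qmax = 255 and scale = (3.0 - (-1.0)) / 255 = 4 / 255 ~= 0.0157;
# initial_zero_point = 0 - (-1.0) / scale = 63.75, which is nudged (rounded) to 64.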
def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
......@@ -29,23 +103,119 @@ class QAT_Quantizer(Quantizer):
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
Parameters
----------
model : torch.nn.Module
the model to be quantized
config_list : list of dict
list of configurations for quantization
supported keys for dict:
- quant_types : list of string
types of quantization you want to apply; currently supports 'weight', 'input', 'output'
- quant_bits : int or dict of {str : int}
bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8};
when an int is given, all quantization types share the same bit length
- quant_start_step : int
disable quantization until the model has run for a certain number of steps; this allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values; default value is 0
- op_types : list of string
types of nn.Module to which quantization is applied, e.g. 'Conv2d'
"""
super().__init__(model, config_list)
self.steps = 1
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", None)
layer.module.register_buffer("scale", None)
if "output" in config.get("quant_types", []):
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))
def quantize_weight(self, weight, config, **kwargs):
if config['q_bits'] <= 1:
def _quantize(self, bits, op, real_val):
"""
quantize real value.
Parameters
----------
bits : int
quantization bits length
op : torch.nn.module
target module
real_val : float
real value to be quantized
Returns
-------
float
"""
transformed_val = op.zero_point + real_val / op.scale
qmin = 0
qmax = (1 << bits) - 1
clamped_val = torch.clamp(transformed_val, qmin, qmax)
quantized_val = torch.round(clamped_val)
return quantized_val
def _dequantize(self, op, quantized_val):
"""
dequantize quantized value.
Because we simulate quantization in the training process, all computations still happen in floating point, which means we
first quantize tensors and then dequantize them. For more details, please refer to the paper.
Parameters
----------
op : torch.nn.Module
target module
quantized_val : float
quantized value to be dequantized
Returns
-------
float
"""
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
def quantize_weight(self, weight, config, op, **kwargs):
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
if quant_start_step > self.steps:
return weight
a = torch.min(weight)
b = torch.max(weight)
n = pow(2, config['q_bits'])
scale = (b-a)/(n-1)
zero_point = a
out = torch.round((weight - zero_point)/scale)
out = out*scale + zero_point
orig_type = weight.dtype
return out.type(orig_type)
rmin, rmax = torch.min(weight), torch.max(weight)
op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, op, weight)
out = self._dequantize(op, out)
return out
def quantize_output(self, output, config, op, **kwargs):
output_bits = get_bits_length(config, 'output')
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
if quant_start_step > self.steps:
return output
current_min, current_max = torch.min(output), torch.max(output)
op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
out = self._quantize(output_bits, op, output)
out = self._dequantize(op, out)
return out
def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def step(self):
"""
override the `Compressor` `step` method; quantization only happens after a certain number of steps
"""
self.steps += 1
class DoReFaQuantizer(Quantizer):
......
......@@ -304,6 +304,12 @@ class Quantizer(Compressor):
assert layer._forward is None, 'Each model can only be compressed once'
assert "quant_types" in config, 'must provide quant_types in config'
assert isinstance(config["quant_types"], list), 'quant_types must be list type'
assert "quant_bits" in config, 'must provide quant_bits in config'
assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
if isinstance(config["quant_bits"], dict):
for quant_type in config["quant_types"]:
assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
if 'weight' in config["quant_types"]:
if not _check_weight(layer.module):
......@@ -312,7 +318,7 @@ class Quantizer(Compressor):
def new_forward(*inputs):
if 'input' in config["quant_types"]:
inputs = self.quantize_input(inputs, config=config, op=layer.module, op_type=layer.type, op_name=layer.name)
inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
if 'weight' in config["quant_types"] and _check_weight(layer.module):
weight = layer.module.weight.data
......@@ -324,12 +330,32 @@ class Quantizer(Compressor):
result = layer._forward(*inputs)
if 'output' in config["quant_types"]:
result = self.quantize_output(result, config, op=layer.module, op_type=layer.type, op_name=layer.name)
result = straight_through_quantize_output.apply(result, self, config, layer)
return result
layer.module.forward = new_forward
class straight_through_quantize_output(torch.autograd.Function):
@staticmethod
def forward(ctx, output, quantizer, config, layer):
return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
class straight_through_quantize_input(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, quantizer, config, layer):
return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
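# Note: backward returns one value per forward argument; the straight-through estimator
# passes grad_output through unchanged for the quantized tensor and returns None for the
# non-tensor arguments (quantizer, config, layer), which require no gradient.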
def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
......