Unverified Commit af929fdb authored by chenbohua3, committed by GitHub

Add LSQ quantizer (#3503)

parent 761732ab
@@ -87,6 +87,8 @@ Quantization algorithms compress the original network by reducing the number of
- DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper <https://arxiv.org/abs/1606.06160>`__
* - `BNN Quantizer <../Compression/Quantizer.rst#bnn-quantizer>`__
- Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper <https://arxiv.org/abs/1602.02830>`__
* - `LSQ Quantizer <../Compression/Quantizer.rst#lsq-quantizer>`__
- Learned Step Size Quantization. `Reference Paper <https://arxiv.org/pdf/1902.08153.pdf>`__
Model Speedup
...
@@ -8,6 +8,7 @@ Index of supported quantization algorithms
* `QAT Quantizer <#qat-quantizer>`__
* `DoReFa Quantizer <#dorefa-quantizer>`__
* `BNN Quantizer <#bnn-quantizer>`__
* `LSQ Quantizer <#lsq-quantizer>`__
Naive Quantizer
---------------
@@ -86,6 +87,61 @@ note
batch normalization folding is currently not supported.
----
LSQ Quantizer
-------------
In `LEARNED STEP SIZE QUANTIZATION <https://arxiv.org/pdf/1902.08153.pdf>`__\ , Steven K. Esser, Jeffrey L. McKinstry et al. provide an algorithm to train the quantization step sizes (scales) with gradients.
..
The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer's quantizer step size, such that it can be learned in conjunction with other network parameters.
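At its core, LSQ treats the quantization step size as a trainable parameter: rounding uses a straight-through estimator so gradients can flow, and the gradient reaching the step size is rescaled for training stability. The snippet below is a minimal, self-contained sketch of that quantize-dequantize step (the function name ``lsq_quantize`` is only illustrative; it mirrors the ``grad_scale``, ``round_pass`` and ``quantize`` helpers used by NNI's ``LsqQuantizer``):

.. code-block:: python

   import torch

   def lsq_quantize(x, scale, qmin, qmax):
       # Rescale the gradient that reaches the step size, as recommended in the paper.
       grad_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
       scale = (scale - scale * grad_factor).detach() + scale * grad_factor
       # Divide by the learned step size, clamp to the integer range, and round
       # with a straight-through estimator so the gradient passes through the rounding.
       q = torch.clamp(x / scale, qmin, qmax)
       q = (q.round() - q).detach() + q
       # Dequantize back to the original value range.
       return q * scale

In the actual quantizer, the step sizes for weights, inputs and outputs are initialized from first-batch statistics (roughly ``2 * mean(|x|) / sqrt(qmax)``) and then optimized together with the network weights.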
Usage
^^^^^
You can add the code below before your training code. Three things must be done:

1. Configure which layers to quantize and which tensors (input/output/weight) of those layers to quantize.
2. Construct the LSQ quantizer.
3. Call the `compress` API.
PyTorch code
.. code-block:: python

   from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

   model = Mnist()
   configure_list = [{
       'quant_types': ['weight', 'input'],
       'quant_bits': {
           'weight': 8,
           'input': 8,
       },
       'op_names': ['conv1']
   }, {
       'quant_types': ['output'],
       'quant_bits': {'output': 8, },
       'op_names': ['relu1']
   }]
   quantizer = LsqQuantizer(model, configure_list, optimizer)
   quantizer.compress()
You can view :githublink:`examples/model_compress/quantization/LSQ_torch_quantizer.py <examples/model_compress/quantization/LSQ_torch_quantizer.py>` for more information.
User configuration for LSQ Quantizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Common configuration needed by compression algorithms can be found in the `specification of config_list <./QuickStart.rst>`__.

Configuration needed by this algorithm:
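No algorithm-specific keys are listed here; the common keys are sufficient. As a reminder, a minimal ``config_list`` entry might look like the sketch below (the layer name ``conv1`` is only illustrative):

.. code-block:: python

   config_list = [{
       'quant_types': ['weight', 'input'],        # which tensors of the layer to quantize
       'quant_bits': {'weight': 8, 'input': 8},   # bit width per quantization type
       'op_names': ['conv1']                      # or use 'op_types' to match layers by type
   }]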
----

DoReFa Quantizer
...
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
    test_loss = 0
    correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        output, time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += time
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed_time (whole dataset): {}s".format(time_elapsed))


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()
    # Specify which layers to quantize, which tensors of each layer, and the bit widths.
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8, },
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['max_pool2']
    }]

    # Pass the optimizer so the quantizer can register the learnable step sizes (scales) with it.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
    for epoch in range(40):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)
    test(model, device, test_loader)
    print("calibration_config: ", calibration_config)

    # Build a TensorRT engine from the calibration parameters exported by the quantizer.
    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)
    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()
    test_trt(engine, test_loader)


if __name__ == '__main__':
    main()
@@ -6,9 +6,9 @@ import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
-__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
logger = logging.getLogger(__name__)
@@ -146,7 +146,7 @@ class QAT_Quantizer(Quantizer):
                types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
-        self.quant_grad = QATGrad
+        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
@@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer):
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
-        self.quant_grad = ClipGrad
+        self.quant_grad = ClipGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
@@ -560,3 +560,205 @@ class BNNQuantizer(Quantizer):
        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
        return calibration_config
class LsqQuantizer(Quantizer):
    """Quantizer defined in:
       Learned Step Size Quantization (ICLR 2020)
       https://arxiv.org/pdf/1902.08153.pdf
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            the model to be quantized
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
                    when the type is int, all quantization types share the same bits length
                - quant_start_step : int
                    disable quantization until the model has been run for a certain number of steps, this allows the network to enter a more stable
                    state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
                - op_types : list of string
                    types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantForward()
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                # todo: support per-channel quantization for weight since TensorRT uses it for conv weight
                q_bit = get_bits_length(config, "weight")
                layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
                layer.module.weight_qmax = qmax
                layer.module.weight_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.weight_scale})

            if "output" in config.get("quant_types", []):
                # scale of activation will be initialized using the first batch of data
                layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                q_bit = get_bits_length(config, "output")
                layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                layer.module.output_qmax = qmax
                layer.module.output_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.output_scale})

            if "input" in config.get("quant_types", []):
                # scale of input will be initialized using the first batch of data
                layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                q_bit = get_bits_length(config, "input")
                layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                layer.module.input_qmax = qmax
                layer.module.input_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.input_scale})

    @staticmethod
    def grad_scale(x, scale):
        r"""
        Used to scale the gradient. Given a tensor `x`, we have `y = grad_scale(x, scale) = x` in the forward pass,
        which means this function does not change the value of `x`. In the backward pass, we have:

        :math:`\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = scale \cdot \frac{\partial L}{\partial y}`

        This means that the original gradient of `x` is scaled by a factor of `scale`. Applying this function
        to an nn.Parameter will scale its gradient without changing its value.
        """
        y = x
        y_grad = x * scale
        return (y - y_grad).detach() + y_grad

    @staticmethod
    def round_pass(x):
        """
        A simple way to implement the straight-through estimator (STE): round in the forward pass,
        pass the gradient through unchanged in the backward pass.
        """
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad

    def quantize(self, x, scale, qmin, qmax):
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = self.grad_scale(scale, grad_scale_factor)
        x = x / scale
        x = torch.clamp(x, qmin, qmax)
        x = self.round_pass(x)
        x = x * scale
        return x

    def quantize_weight(self, wrapper, **kwargs):
        module = wrapper.module

        # todo: add support for quantizing bias. If we use TensorRT as the backend, there is no need to quantize bias
        old_weight = module.old_weight
        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
        module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        module = wrapper.module

        # initialize the scale
        if self.bound_model.steps == 1:
            qmax = module.output_qmax
            init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
            module.output_scale.data = init_oup_scale

        output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
        return output

    def quantize_input(self, *inputs, wrapper, **kwargs):
        # This is hacky since it is not recommended to modify a tuple
        # NB: support layers with multiple inputs
        module = wrapper.module

        # initialize the scale
        if self.bound_model.steps == 1:
            qmax = module.input_qmax
            init_inp_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
            module.input_scale.data = init_inp_scale

        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
        list_inp = list(inputs)
        list_inp[0] = new_input
        return tuple(list_inp)

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bit'):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
                abs_max_input = float(module.input_scale * module.input_qmax)
                calibration_config[name]['tracked_min_input'] = -abs_max_input
                calibration_config[name]['tracked_max_input'] = abs_max_input
            if hasattr(module, 'output_bit'):
                calibration_config[name]['activation_bit'] = int(module.output_bit)
                abs_max_output = float(module.output_scale * module.output_qmax)
                calibration_config[name]['tracked_min_activation'] = -abs_max_output
                calibration_config[name]['tracked_max_activation'] = abs_max_output
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in the quantized module
        """
        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation',
                         'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale',
                         'weight_bit', 'output_bit', 'input_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def step_with_optimizer(self):
        """
        override `compressor` `step` method, quantization only happens after a certain number of steps
        """
        self.bound_model.steps += 1
@@ -474,13 +474,13 @@ class QuantizerModuleWrapper(torch.nn.Module):
    def forward(self, *inputs):
        if 'input' in self.config['quant_types']:
-            inputs = self.quantizer.quant_grad.apply(
+            inputs = self.quantizer.quant_grad(
                inputs,
                QuantType.QUANT_INPUT,
                self)

        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad(
                self.module.old_weight,
                QuantType.QUANT_WEIGHT,
                self, inputs[0])
@@ -489,12 +489,13 @@ class QuantizerModuleWrapper(torch.nn.Module):
        result = self.module(*inputs)

        if 'output' in self.config['quant_types']:
-            result = self.quantizer.quant_grad.apply(
+            result = self.quantizer.quant_grad(
                result,
                QuantType.QUANT_OUTPUT,
                self)
        return result


class Quantizer(Compressor):
    """
    Base quantizer for pytorch quantizer
@@ -502,7 +503,7 @@ class Quantizer(Compressor):
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
-        self.quant_grad = QuantGrad
+        self.quant_grad = QuantGrad.apply
        if self.optimizer is not None:
            self.patch_optimizer(self.step_with_optimizer)
        for wrapper in self.get_modules_wrapper():
@@ -719,15 +720,7 @@ class QuantGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
-        if quant_type == QuantType.QUANT_INPUT:
-            output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
-        elif quant_type == QuantType.QUANT_WEIGHT:
-            output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
-        elif quant_type == QuantType.QUANT_OUTPUT:
-            output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
-        else:
-            raise ValueError("unrecognized QuantType.")
+        output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)
        bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
        qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
@@ -750,3 +743,24 @@ def _check_weight(module):
        return isinstance(module.weight.data, torch.Tensor)
    except AttributeError:
        return False
def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
    if quant_type == QuantType.QUANT_INPUT:
        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
    elif quant_type == QuantType.QUANT_WEIGHT:
        output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
    elif quant_type == QuantType.QUANT_OUTPUT:
        output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
    else:
        raise ValueError("unrecognized QuantType.")

    return output


class QuantForward(torch.nn.Module):
    """
    Base class for executing quantization operations. This is for quantization algorithms
    that do not need to customize gradient.
    """

    def forward(self, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
        return quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)