Unverified commit ede59380 authored by Cjkkkk, committed by GitHub

Patch optimizer for pruner (#2058)

parent f86c7005
## Overview
The model compression framework has two main components: `pruner` and `module wrapper`.
### pruner
A `pruner` is responsible for:
1. providing a `calc_mask` method that calculates masks for the weight and bias,
2. replacing modules with `module wrapper`s according to the config,
3. modifying the optimizer so that the `calc_mask` method is called every time `step` is called.
### module wrapper
A `module wrapper` is a module containing:
1. the original module,
2. some buffers used by `calc_mask`,
3. a new forward method that applies the masks before running the original forward method.
The reasons to use a `module wrapper`:
1. `calc_mask` needs some buffers to calculate masks, and these buffers should be registered in the `module wrapper` so that the original module is not contaminated.
2. a new `forward` method is needed to apply the masks to the weight before calling the real `forward` method (see the sketch below).
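A condensed sketch of such a wrapper, adapted from the framework's `PrunerModuleWrapper` (bias handling and mask updates are omitted for brevity):
```python
import torch

class PrunerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, pruner):
        super().__init__()
        self.module = module    # the original module
        self.name = module_name
        self.type = module_type
        self.config = config
        self.pruner = pruner
        # buffer used by `calc_mask`; being a buffer, it is saved by model.state_dict
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))

    def forward(self, *inputs):
        # apply the mask to the weight before running the original forward method
        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
        return self.module(*inputs)
```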
## How it works
Basic pruner usage:
```python
configure_list = [{
    'sparsity': 0.7,
    'op_types': ['BatchNorm2d'],
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = SlimPruner(model, configure_list, optimizer)
model = pruner.compress()
```
A pruner receives the model, config list, and optimizer as arguments. In the `__init__` method, the optimizer's `step` method is replaced with a new `step` method that also calls `calc_mask`. All modules are then checked against the config to see whether they should be pruned; each module that should be pruned is replaced with a `module wrapper`. Afterward, the new model and new optimizer are returned and can be trained as before. The `compress` method calculates the default masks.
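The patching itself is a small closure around the original `step` (this mirrors the framework's `patch_optimizer`, which appears in the diff below):
```python
import types

def patch_optimizer(self, *tasks):
    def patch_step(old_step):
        def new_step(_, *args, **kwargs):
            # call the original optimizer step method
            output = old_step(*args, **kwargs)
            # then run the registered tasks, e.g. mask calculation
            for task in tasks:
                task()
            return output
        return new_step
    self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
```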
## Implement a new pruning algorithm
Implementing a new pruning algorithm requires writing a new `pruner` class, which should subclass `Pruner` and override the `calc_mask` method. `calc_mask` is called by the patched `optimizer.step` method.
The `Pruner` base class provides the basic functionality listed above, for example, replacing modules and patching the optimizer.
A basic pruner looks like this:
```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        # do some initialization

    def calc_mask(self, wrapper, **kwargs):
        # do something to calculate weight_mask
        wrapper.weight_mask = weight_mask
```
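The new pruner is then used like any built-in one (reusing the `configure_list` and `optimizer` from the earlier example):
```python
pruner = NewPruner(model, configure_list, optimizer)
model = pruner.compress()
```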
### Set wrapper attribute
Sometimes `calc_mask` must save some state data; users can use the `set_wrappers_attribute` API to register attributes, just like buffers are registered in PyTorch modules. These buffers are registered to the `module wrapper`, and users can access them through the `module wrapper`.
```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper):
        # do something to calculate weight_mask
        if wrapper.if_calculated:
            pass
        else:
            wrapper.if_calculated = True
            # update masks
```
### Collect data during forward
Sometimes users want to collect data during the modules' forward method, for example, the mean value of the activations. This can be done by adding a customized collector to the module.
```python
class ActivationRankFilterPruner(Pruner):
    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        self.set_wrappers_attribute("collected_activation", [])
        self.statistics_batch_num = statistics_batch_num

        def collector(module_, input_, output):
            if len(module_.collected_activation) < self.statistics_batch_num:
                module_.collected_activation.append(self.activation(output.detach().cpu()))
        self.add_activation_collector(collector)

        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'relu6':
            self.activation = torch.nn.functional.relu6
        else:
            self.activation = None
```
The collector function is called each time the `forward` method runs.
Users can also remove the collector:
```python
collector_id = self.add_activation_collector(collector)
# ...
self.remove_activation_collector(collector_id)
```
### Multi-GPU support
In multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in `forward`, an in-place update is needed for the update to be effective.
Since `calc_mask` is called in the `optimizer.step` method, which runs after `forward` and on only one GPU, multi-GPU training is supported naturally.
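A typical multi-GPU setup therefore needs no special handling (a sketch, reusing `model`, `configure_list`, and `optimizer` from the first example):
```python
pruner = SlimPruner(model, configure_list, optimizer)
model = pruner.compress()
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
# train as usual; masks are recalculated inside optimizer.step()
```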
......@@ -3,7 +3,7 @@ As larger neural networks with more layers and nodes are considered, reducing th
We are glad to introduce model compression toolkit on top of NNI, it's still in the experiment phase which might evolve based on usage feedback. We'd like to invite you to use, feedback and even contribute.
NNI provides an easy-to-use toolkit to help user design and use compression algorithms. It currently supports PyTorch with unified interface. For users to compress their models, they only need to add several lines in their code. There are some popular model compression algorithms built-in in NNI. Users could further use NNI's auto tuning power to find the best compressed model, which is detailed in [Auto Model Compression](./AutoCompression.md). On the other hand, users could easily customize their new compression algorithms using NNI's interface, refer to the tutorial [here](#customize-new-compression-algorithms).
NNI provides an easy-to-use toolkit to help user design and use compression algorithms. It currently supports PyTorch with unified interface. For users to compress their models, they only need to add several lines in their code. There are some popular model compression algorithms built-in in NNI. Users could further use NNI's auto tuning power to find the best compressed model, which is detailed in [Auto Model Compression](./AutoCompression.md). On the other hand, users could easily customize their new compression algorithms using NNI's interface, refer to the tutorial [here](#customize-new-compression-algorithms). Details about how model compression framework works can be found in [here](./Framework.md).
For a survey of model compression, you can refer to this paper: [Recent Advances in Efficient Computation of Deep Convolutional Neural Networks](https://arxiv.org/pdf/1802.00939.pdf).
......
......@@ -21,3 +21,4 @@ For details, please refer to the following tutorials:
Quantizers <quantizers>
Model Speedup <Compressor/ModelSpeedup>
Automatic Model Compression <Compressor/AutoCompression>
Implementation <Compressor/Framework>
......@@ -88,14 +88,14 @@ def main():
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = L1FilterPruner(model, configure_list, optimizer_finetune)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 88.19%
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
......
......@@ -79,15 +79,15 @@ def main():
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'quant_start_step': 1000,
'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, configure_list)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer)
quantizer.compress()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
for epoch in range(40):
print('# Epoch {} #'.format(epoch))
train(model, quantizer, device, train_loader, optimizer)
test(model, device, test_loader)
......
......@@ -71,7 +71,6 @@ if __name__ == '__main__':
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()
#model = nn.DataParallel(model)
for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
......
......@@ -215,7 +215,7 @@ def main(args):
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
pruner = create_pruner(model, args.pruner_name)
pruner = create_pruner(model, args.pruner_name, optimizer_finetune)
model = pruner.compress()
if args.multi_gpu and torch.cuda.device_count() > 1:
......
......@@ -107,7 +107,8 @@ def main():
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
pruner = SlimPruner(model, configure_list)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = SlimPruner(model, configure_list, optimizer_finetune)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
......@@ -119,13 +120,12 @@ def main():
model.to(device)
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .compressor import Compressor, Pruner, Quantizer
from .pruners import *
from .weight_rank_filter_pruners import *
from .activation_rank_filter_pruners import *
......
......@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -31,11 +31,15 @@ class ActivationRankFilterPruner(Pruner):
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
self.set_wrappers_attribute("collected_activation", [])
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
def collector(module_, input_, output):
if len(module_.collected_activation) < self.statistics_batch_num:
module_.collected_activation.append(self.activation(output.detach().cpu()))
self.add_activation_collector(collector)
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
......@@ -44,33 +48,10 @@ class ActivationRankFilterPruner(Pruner):
else:
self.activation = None
def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
......@@ -88,29 +69,30 @@ class ActivationRankFilterPruner(Pruner):
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
if filters < 2 or num_prune < 1 or len(wrapper.collected_activation) < self.statistics_batch_num:
return mask
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
mask = self.get_mask(mask, wrapper.collected_activation, num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
if len(wrapper.collected_activation) == self.statistics_batch_num:
wrapper.if_calculated = True
return mask
......@@ -123,7 +105,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -137,7 +119,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
......@@ -161,9 +143,9 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
......@@ -195,7 +177,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -209,7 +191,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
......@@ -233,9 +215,9 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import types
import logging
import torch
from . import default_layers
......@@ -20,12 +21,13 @@ def _setattr(model, name, module):
model = getattr(model, name)
setattr(model, name_list[-1], module)
class Compressor:
"""
Abstract base PyTorch compressor
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Record necessary info in class members
......@@ -35,15 +37,27 @@ class Compressor:
the model user wants to compress
config_list : list
the configurations that users specify for compression
optimizer: pytorch optimizer
optimizer used to train the model
"""
self.bound_model = model
self.config_list = config_list
self.optimizer = optimizer
self.modules_to_compress = None
self.modules_wrapper = None
self.buffers = {}
self.modules_wrapper = []
self.is_wrapped = False
def detect_modules_to_compress(self):
self._fwd_hook_handles = {}
self._fwd_hook_id = 0
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
def _detect_modules_to_compress(self):
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
......@@ -87,26 +101,26 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
return self.bound_model
def register_buffer(self, name, value):
def set_wrappers_attribute(self, name, value):
"""
To register buffers used in wrapped module's forward method.
To register attributes used in wrapped module's forward method.
If the type of the value is Torch.tensor, then this value is registered as a buffer in wrapper,
which will be saved by model.state_dict. Otherwise, this value is just a regular variable in wrapper.
Parameters
----------
name : str
name of the variable
value: any
value of the variable
"""
self.buffers[name] = value
for wrapper in self.get_modules_wrapper():
if isinstance(value, torch.Tensor):
wrapper.register_buffer(name, value.clone())
else:
setattr(wrapper, name, value)
def get_modules_to_compress(self):
"""
......@@ -180,11 +194,7 @@ class Compressor:
epoch : num
the current epoch number
"""
def step(self):
"""
If user want to update model every step, user can override this method
"""
pass
def _wrap_modules(self, layer, config):
"""
......@@ -200,6 +210,33 @@ class Compressor:
raise NotImplementedError()
def add_activation_collector(self, collector):
self._fwd_hook_id += 1
self._fwd_hook_handles[self._fwd_hook_id] = []
for wrapper in self.get_modules_wrapper():
handle = wrapper.register_forward_hook(collector)
self._fwd_hook_handles[self._fwd_hook_id].append(handle)
return self._fwd_hook_id
def remove_activation_collector(self, fwd_hook_id):
if fwd_hook_id not in self._fwd_hook_handles:
raise ValueError("%s is not a valid collector id" % str(fwd_hook_id))
for handle in self._fwd_hook_handles[fwd_hook_id]:
handle.remove()
del self._fwd_hook_handles[fwd_hook_id]
def patch_optimizer(self, *tasks):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
# call origin optimizer step method
output = old_step(*args, **kwargs)
# calculate mask
for task in tasks:
task()
return output
return new_step
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
......@@ -226,7 +263,6 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
self.config = config
self.pruner = pruner
self.registered_buffers = []
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
......@@ -234,29 +270,11 @@ class PrunerModuleWrapper(torch.nn.Module):
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
self.registered_buffers.append('weight_mask')
self.registered_buffers.append('bias_mask')
# register user specified buffer
for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers.append(name)
def get_registered_buffers(self):
buffers = {}
for name in self.registered_buffers:
buffers[name] = getattr(self, name)
return buffers
def forward(self, *inputs):
mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
if mask is not None:
self.weight_mask.copy_(mask['weight'])
# apply mask to weight
# apply mask to weight, bias
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
# apply mask to bias
if hasattr(self.module, 'bias') and self.module.bias is not None:
if mask is not None and 'bias' in mask:
self.bias_mask.copy_(mask['bias'])
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
......@@ -272,10 +290,23 @@ class Pruner(Compressor):
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer):
super().__init__(model, config_list, optimizer)
self.patch_optimizer(self.update_mask)
def calc_mask(self, layer, config, **kwargs):
def compress(self):
self.update_mask()
return self.bound_model
def update_mask(self):
for wrapper in self.get_modules_wrapper():
masks = self.calc_mask(wrapper)
if masks is not None:
for k in masks:
assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k
setattr(wrapper, k, masks[k])
def calc_mask(self, wrapper, **kwargs):
"""
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
......@@ -284,10 +315,8 @@ class Pruner(Compressor):
Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
config : dict
the configuration for generating the mask
wrapper : Module
calculate mask for `wrapper.module`'s weight
"""
raise NotImplementedError("Pruners must overload calc_mask()")
......@@ -327,8 +356,6 @@ class Pruner(Compressor):
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
......@@ -404,7 +431,6 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
self.config = config
self.quantizer = quantizer
self.registered_buffers = []
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
......@@ -418,35 +444,18 @@ class QuantizerModuleWrapper(torch.nn.Module):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
# register user specified buffer
for name in self.quantizer.buffers:
self.register_buffer(name, self.quantizer.buffers[name].clone())
self.registered_buffers.append(name)
def get_registered_buffers(self):
buffers = {}
for name in self.registered_buffers:
buffers[name] = getattr(self, name)
return buffers
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad.apply(
inputs,
QuantType.QUANT_INPUT,
self.quantizer.quantize_input,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
new_weight = self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self.quantizer.quantize_weight,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
self.module.weight = new_weight
result = self.module(*inputs)
else:
......@@ -456,10 +465,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
result = self.quantizer.quant_grad.apply(
result,
QuantType.QUANT_OUTPUT,
self.quantizer.quantize_output,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
return result
class Quantizer(Compressor):
......@@ -467,11 +473,18 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad
if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer)
for wrapper in self.get_modules_wrapper():
if 'weight' in wrapper.config['quant_types']:
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
def quantize_weight(self, weight, config, op, op_type, op_name):
def quantize_weight(self, weight, wrapper, **kwargs):
"""
quantize should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -479,12 +492,12 @@ class Quantizer(Compressor):
----------
weight : Tensor
weight that needs to be quantized
config : dict
the configuration for weight quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_weight()')
def quantize_output(self, output, config, op, op_type, op_name):
def quantize_output(self, output, wrapper, **kwargs):
"""
quantize should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -492,12 +505,12 @@ class Quantizer(Compressor):
----------
output : Tensor
output that needs to be quantized
config : dict
the configuration for output quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_output()')
def quantize_input(self, *inputs, config, op, op_type, op_name):
def quantize_input(self, *inputs, wrapper, **kwargs):
"""
quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -505,8 +518,8 @@ class Quantizer(Compressor):
----------
inputs : Tensor
inputs that needs to be quantized
config : dict
the configuration for inputs quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_input()')
......@@ -532,6 +545,9 @@ class Quantizer(Compressor):
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def step_with_optimizer(self):
pass
class QuantType:
"""
Enum class for quantization type.
......@@ -540,6 +556,7 @@ class QuantType:
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
......@@ -566,15 +583,22 @@ class QuantGrad(torch.autograd.Function):
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None, None
return output, None, None, None
def _check_weight(module):
try:
......
......@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
Prune to an exact pruning level specification
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -26,36 +26,35 @@ class LevelPruner(Pruner):
List on pruning configs
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
wrapper : Module
the module to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
if_calculated = kwargs["if_calculated"]
config = wrapper.config
weight = wrapper.module.weight.data
if not if_calculated:
if not wrapper.if_calculated:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
mask = {'weight_mask': mask_weight}
wrapper.if_calculated = True
return mask
else:
return None
......@@ -71,7 +70,7 @@ class AGP_Pruner(Pruner):
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -81,48 +80,45 @@ class AGP_Pruner(Pruner):
List on pruning configs
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
self.now_epoch = 0
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
config = wrapper.config
weight = wrapper.module.weight.data
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
mask = {'weight_mask': wrapper.weight_mask}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask['weight']
w_abs = weight.abs() * mask['weight_mask']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
new_mask = {'weight_mask': torch.gt(w_abs, threshold).type_as(weight)}
wrapper.if_calculated = True
return new_mask
......@@ -180,7 +176,7 @@ class SlimPruner(Pruner):
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -189,53 +185,51 @@ class SlimPruner(Pruner):
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
for (layer, config) in self.get_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
if_calculated = kwargs["if_calculated"]
config = wrapper.config
weight = wrapper.module.weight.data
op_type = wrapper.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if if_calculated:
if wrapper.if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
mask = {'weight_mask': base_mask.detach(), 'bias_mask': base_mask.clone().detach()}
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
wrapper.if_calculated = True
return mask
class LotteryTicketPruner(Pruner):
......@@ -267,7 +261,7 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = self._validate_config(config_list)
......@@ -307,20 +301,16 @@ class LotteryTicketPruner(Pruner):
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return {'weight': mask}
return {'weight_mask': mask}
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Generate mask for the given ``weight``.
Parameters
----------
layer : LayerInfo
wrapper : Module
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
......@@ -355,7 +345,7 @@ class LotteryTicketPruner(Pruner):
assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'
modules_wrapper = self.get_modules_wrapper()
modules_to_compress = self.detect_modules_to_compress()
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
module_wrapper = None
for wrapper in modules_wrapper:
......@@ -367,7 +357,7 @@ class LotteryTicketPruner(Pruner):
sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
# TODO: directly use weight_mask is not good
module_wrapper.weight_mask.copy_(mask['weight'])
module_wrapper.weight_mask = mask['weight_mask']
# there is no mask for bias
# reinit weights back to original after new masks are generated
......
......@@ -13,14 +13,14 @@ logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
def quantize_weight(self, weight, config, op_name, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(op_name, 0), new_scale)
self.layer_scale[op_name] = scale
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
self.layer_scale[wrapper.name] = scale
orig_type = weight.type() # TODO: user layer
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
......@@ -104,7 +104,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -124,9 +124,9 @@ class QAT_Quantizer(Quantizer):
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
self.steps = 1
modules_to_compress = self.detect_modules_to_compress()
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", None)
layer.module.register_buffer("scale", None)
......@@ -181,7 +181,9 @@ class QAT_Quantizer(Quantizer):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
def quantize_weight(self, weight, config, op, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
......@@ -189,12 +191,14 @@ class QAT_Quantizer(Quantizer):
if quant_start_step > self.steps:
return weight
rmin, rmax = torch.min(weight), torch.max(weight)
op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, op, weight)
out = self._dequantize(op, out)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, module, weight)
out = self._dequantize(module, out)
return out
def quantize_output(self, output, config, op, **kwargs):
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
......@@ -203,18 +207,18 @@ class QAT_Quantizer(Quantizer):
return output
current_min, current_max = torch.min(output), torch.max(output)
op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
out = self._quantize(output_bits, op, output)
out = self._dequantize(op, out)
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def step(self):
def step_with_optimizer(self):
"""
override `compressor` `step` method, quantization only happens after certain number of steps
"""
......@@ -226,11 +230,11 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
def quantize_weight(self, weight, config, **kwargs):
weight_bits = get_bits_length(config, 'weight')
def quantize_weight(self, weight, wrapper, **kwargs):
weight_bits = get_bits_length(wrapper.config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, weight_bits)
......@@ -256,17 +260,17 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
def quantize_weight(self, weight, config, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out
def quantize_output(self, output, config, **kwargs):
def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
......
......@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
importance criterion in convolution layers to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -26,13 +26,13 @@ class WeightRankFilterPruner(Pruner):
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
......@@ -48,20 +48,21 @@ class WeightRankFilterPruner(Pruner):
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
assert op_type in config.get('op_types')
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
......@@ -69,7 +70,7 @@ class WeightRankFilterPruner(Pruner):
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
wrapper.if_calculated = True
return mask
......@@ -82,7 +83,7 @@ class L1FilterPruner(WeightRankFilterPruner):
https://arxiv.org/abs/1608.08710
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -93,7 +94,7 @@ class L1FilterPruner(WeightRankFilterPruner):
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -121,7 +122,7 @@ class L1FilterPruner(WeightRankFilterPruner):
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
class L2FilterPruner(WeightRankFilterPruner):
......@@ -130,7 +131,7 @@ class L2FilterPruner(WeightRankFilterPruner):
smallest L2 norm of the weights.
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -141,7 +142,7 @@ class L2FilterPruner(WeightRankFilterPruner):
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -167,7 +168,7 @@ class L2FilterPruner(WeightRankFilterPruner):
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
class FPGMPruner(WeightRankFilterPruner):
......@@ -177,7 +178,7 @@ class FPGMPruner(WeightRankFilterPruner):
https://arxiv.org/pdf/1811.00250.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -187,7 +188,7 @@ class FPGMPruner(WeightRankFilterPruner):
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -208,9 +209,9 @@ class FPGMPruner(WeightRankFilterPruner):
"""
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
......@@ -258,4 +259,4 @@ class FPGMPruner(WeightRankFilterPruner):
def update_epoch(self, epoch):
for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
wrapper.if_calculated = False
......@@ -92,8 +92,9 @@ class CompressorTestCase(TestCase):
def test_torch_level_pruner(self):
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(model, configure_list).compress()
torch_compressor.LevelPruner(model, configure_list, optimizer).compress()
@tf2
def test_tf_level_pruner(self):
......@@ -130,22 +131,24 @@ class CompressorTestCase(TestCase):
"""
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}]
pruner = torch_compressor.FPGMPruner(model, config_list)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
pruner = torch_compressor.FPGMPruner(model, config_list, optimizer)
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
model.conv2.module.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(model.conv2)
assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
model.conv2.module.weight.data = torch.tensor(w).float()
model.conv2.if_calculated = False
model.conv2.config = config_list[0]
masks = pruner.calc_mask(model.conv2)
assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2
def test_tf_fpgm_pruner(self):
model = get_tf_model()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}]
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}]
pruner = tf_compressor.FPGMPruner(model, config_list)
weights = model.layers[2].weights
......@@ -158,11 +161,6 @@ class CompressorTestCase(TestCase):
assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((1)) == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
def test_torch_l1filter_pruner(self):
"""
Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
......@@ -178,18 +176,17 @@ class CompressorTestCase(TestCase):
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
{'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)
pruner = torch_compressor.L1FilterPruner(model, config_list, optimizer)
model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
model.conv1.module.weight.data = torch.tensor(w).float()
model.conv2.module.weight.data = torch.tensor(w).float()
mask1 = pruner.calc_mask(model.conv1)
mask2 = pruner.calc_mask(model.conv2)
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
def test_torch_slim_pruner(self):
"""
......@@ -207,33 +204,32 @@ class CompressorTestCase(TestCase):
"""
w = np.array([0, 1, 2, 3, 4])
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(-w).float()
pruner = torch_compressor.SlimPruner(model, config_list)
pruner = torch_compressor.SlimPruner(model, config_list, optimizer)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
mask1 = pruner.calc_mask(model.bn1)
mask2 = pruner.calc_mask(model.bn2)
assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(w).float()
pruner = torch_compressor.SlimPruner(model, config_list)
pruner = torch_compressor.SlimPruner(model, config_list, optimizer)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
mask1 = pruner.calc_mask(model.bn1)
mask2 = pruner.calc_mask(model.bn2)
assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
def test_torch_QAT_quantizer(self):
model = TorchModel()
......@@ -254,14 +250,14 @@ class CompressorTestCase(TestCase):
# range not including 0
eps = 1e-7
weight = torch.tensor([[1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
assert model.conv2.zero_point == 0
quantize_weight = quantizer.quantize_weight(weight, model.conv2)
assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
assert model.conv2.module.zero_point == 0
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)
assert model.conv2.zero_point in (42, 43)
quantize_weight = quantizer.quantize_weight(weight, model.conv2)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
# test ema
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
......@@ -269,7 +265,7 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
quantizer.step()
quantizer.step_with_optimizer()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
......