"src/include/ConstantMatrixDescriptor.hpp" did not exist on "f35c64eb78af4754e78f8746c8e28d2ac8b68e80"
Unverified commit 89fa23cb authored by chicm-ms, committed by GitHub

Refactor model pruning framework (#2504)

parent 96207cb5
# Design Doc

## Overview

The following example shows how to use a pruner:

```python
from nni.compression.torch import LevelPruner

# load a pretrained model or train a model before using a pruner

configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# the model is ready for pruning; now start finetuning the model,
# the model will be pruned during training automatically
```
A pruner receives `model`, `config_list` and `optimizer` as arguments. It prunes the model per the `config_list` during the training loop by adding a hook on `optimizer.step()`.

From an implementation perspective, a pruner consists of a `weight masker` instance and multiple `module wrapper` instances.

### Weight masker

A `weight masker` is the implementation of a pruning algorithm; it can prune a specified layer, wrapped by a `module wrapper`, with a specified sparsity.
### Module wrapper
A `module wrapper` is a module containing:
1. the origin module
2. some buffers used by `calc_mask`
3. a new forward method that applies masks before running the original forward method.
the reasons to use `module wrapper`:
1. some buffers are needed by `calc_mask` to calculate masks and these buffers should be registered in `module wrapper` so that the original modules are not contaminated.
2. a new `forward` method is needed to apply masks to weight before calling the real `forward` method.
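To make this concrete, below is a minimal sketch of what such a wrapper does. The class name and details are hypothetical and simplified; the actual NNI wrapper implementation differs.

```python
import torch

class ModuleWrapperSketch(torch.nn.Module):
    """Hypothetical, simplified module wrapper (not the actual NNI implementation)."""
    def __init__(self, module):
        super().__init__()
        # 1. keep the original module
        self.module = module
        # 2. buffers used by calc_mask are registered on the wrapper,
        #    so the original module is not contaminated
        self.register_buffer("weight_mask", torch.ones_like(module.weight))

    def forward(self, *inputs):
        # 3. apply the mask to the weight before calling the real forward
        self.module.weight.data = self.module.weight.data.mul(self.weight_mask)
        return self.module(*inputs)
```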
### Pruner

A `pruner` is responsible for:

1. managing and verifying the `config_list`;
2. wrapping the model layers with `module wrapper`s and adding a hook on `optimizer.step`;
3. using the `weight masker` to calculate the masks of layers while pruning;
4. exporting the pruned model weights and masks.
## Implement a new pruning algorithm

Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.

An implementation of a `weight masker` may look like this:

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting statistics,
        # if it is necessary for your algorithm to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the mask based on wrapper.weight, the sparsity,
        # and whatever else your algorithm needs
        # mask = ...
        return {'weight_mask': mask}
```
You can reference the NNI-provided [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations to implement your own weight masker.

A basic pruner looks like this:

```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        # construct a weight masker instance
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # already pruned; as a one-shot pruner, do not prune again
            return None
        else:
            # call your masker to actually calculate the mask for this layer
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            return masks
```
Reference the NNI-provided [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations to implement your own pruner class.

### Set wrapper attribute

Sometimes `calc_mask` must save some state data, so users can use the `set_wrappers_attribute` API to register an attribute, just like how buffers are registered in PyTorch modules. These buffers will be registered to the `module wrapper`, and users can access them through the `module wrapper`.

In the above example, we use `set_wrappers_attribute` to set a buffer `if_calculated`, which is used as a flag indicating whether the mask of a layer has already been calculated.
### Collect data during forward

Sometimes users want to collect some data during the modules' forward method, for example, the mean value of the activations. This can be done by adding a customized collector to the module.

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # Set attribute `collected_activation` for all wrappers to store
        # the activations of each layer
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.activation = torch.nn.functional.relu

        def collector(wrapper, input_, output):
            # The collected activations can be accessed via each wrapper's
            # `collected_activation` attribute
            wrapper.collected_activation.append(self.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```
The collector function will be called each time the forward method runs.

Users can also remove this collector, like this:

```python
# Save the collector identifier
collector_id = self.pruner.add_activation_collector(collector)

# ...

# When the collector is no longer needed, it can be removed using
# the saved collector identifier
self.pruner.remove_activation_collector(collector_id)
```
### Multi-GPU support

On multi-GPU training, buffers and parameters are copied to multiple GPUs every time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an in-place update is needed to ensure the update is effective.

Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and happens only on one GPU, it supports multi-GPU naturally.
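To illustrate the in-place requirement, here is a generic PyTorch sketch (not NNI code): rebinding an attribute only changes the local reference, while an in-place update modifies the storage that all holders of the buffer see.

```python
import torch

mask_buffer = torch.ones(4)

def rebind(buf, new_mask):
    buf = new_mask       # rebinds the local name only; the caller's buffer is unchanged
    return buf

def update_in_place(buf, new_mask):
    buf.copy_(new_mask)  # writes into the existing storage, so the change is visible everywhere

rebind(mask_buffer, torch.zeros(4))
print(mask_buffer)       # tensor([1., 1., 1., 1.]) -- unchanged

update_in_place(mask_buffer, torch.zeros(4))
print(mask_buffer)       # tensor([0., 0., 0., 0.]) -- updated in place
```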
@@ -80,11 +80,21 @@ config_list = [{

    'frequency': 1,
    'op_types': ['default']
}]
pruner = AGP_Pruner(model, config_list, pruning_algorithm='level')
pruner.compress()
```
AGP pruner uses the `level` pruning algorithm (`LevelPruner`) to prune the weights by default; however, you can set the `pruning_algorithm` parameter to other values to use other pruning algorithms (a usage sketch follows this list):

* `level`: LevelPruner
* `slim`: SlimPruner
* `l1`: L1FilterPruner
* `l2`: L2FilterPruner
* `fpgm`: FPGMPruner
* `taylorfo`: TaylorFOWeightFilterPruner
* `apoz`: ActivationAPoZRankFilterPruner
* `mean_activation`: ActivationMeanRankFilterPruner
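For example, a sketch of pairing the AGP schedule with the `l1` algorithm; `model`, `optimizer`, `train` and `num_epochs` are assumed to be defined in your own training script:

```python
from nni.compression.torch import AGP_Pruner

config_list = [{
    'initial_sparsity': 0.0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
    'end_epoch': 10,
    'frequency': 1,
    'op_types': ['Conv2d']
}]
pruner = AGP_Pruner(model, config_list, optimizer, pruning_algorithm='l1')
pruner.compress()

for epoch in range(num_epochs):
    train(model, optimizer)      # your ordinary training code for one epoch
    pruner.update_epoch(epoch)   # let AGP advance its sparsity schedule
```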
You should add the code below to update the epoch number when you finish one epoch in your training code.

Tensorflow code

```python

@@ -209,7 +219,7 @@ pruner.compress()

```

Note: FPGM Pruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers.

You should add the code below to update the epoch number at the beginning of each epoch.

Tensorflow code

```python
...
```
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .finegrained_pruning import *
from .structured_pruning import *
from .apply_compression import apply_compression_results
from .one_shot import *
from .agp import *
from .lottery_ticket import LotteryTicketPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import And, Optional
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
logger = logging.getLogger('torch activation rank filter pruners')
class ActivationRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers (using activation values)
to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
self.statistics_batch_num = statistics_batch_num
self.hook_id = self._add_activation_collector()
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.activation = torch.nn.functional.relu6
else:
self.activation = None
def _add_activation_collector(self):
def collector(collected_activation):
def hook(module_, input_, output):
collected_activation.append(self.activation(output.detach().cpu()))
return hook
self.collected_activation = {}
self._fwd_hook_id += 1
self._fwd_hook_handles[self._fwd_hook_id] = []
for wrapper_idx, wrapper in enumerate(self.get_modules_wrapper()):
self.collected_activation[wrapper_idx] = []
handle = wrapper.register_forward_hook(collector(self.collected_activation[wrapper_idx]))
self._fwd_hook_handles[self._fwd_hook_id].append(handle)
return self._fwd_hook_id
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, wrapper, wrapper_idx, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
acts = self.collected_activation[wrapper_idx]
if filters < 2 or num_prune < 1 or len(acts) < self.statistics_batch_num:
return mask
mask = self.get_mask(mask, acts, num_prune)
finally:
if len(acts) >= self.statistics_batch_num:
wrapper.if_calculated = True
if self.hook_id in self._fwd_hook_handles:
self.remove_activation_collector(self.hook_id)
return mask
class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest APoZ(average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the largest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import And, Optional
from .constants import MASKER_DICT
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['AGP_Pruner']
logger = logging.getLogger('torch pruner')
class AGP_Pruner(Pruner):
"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
pruning_algorithm: str
algorithms being used to prune model
"""
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self.masker = MASKER_DICT[pruning_algorithm](model, self)
self.now_epoch = 0
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
'final_sparsity': And(float, lambda n: 0 <= n <= 1),
'start_epoch': And(int, lambda n: n >= 0),
'end_epoch': And(int, lambda n: n >= 0),
'frequency': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict | None
Dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
"""
config = wrapper.config
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if wrapper.if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
target_sparsity = self.compute_target_sparsity(config)
new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
if new_mask is not None:
wrapper.if_calculated = True
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
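For intuition about the cubic schedule implemented in `compute_target_sparsity` above, here is a small standalone sketch (not part of the NNI source) that mirrors the main branch of that method and shows how the target sparsity ramps up over epochs:

```python
def agp_target_sparsity(now_epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=0, end_epoch=10, frequency=1):
    # mirrors AGP_Pruner.compute_target_sparsity for the common case
    if end_epoch <= now_epoch:
        return final_sparsity
    span = ((end_epoch - start_epoch - 1) // frequency) * frequency
    return (final_sparsity + (initial_sparsity - final_sparsity) *
            (1.0 - (now_epoch - start_epoch) / span) ** 3)

for epoch in [0, 3, 6, 9]:
    print(epoch, round(agp_target_sparsity(epoch), 3))
# 0 0.0, 3 0.563, 6 0.77, 9 0.8 -- sparsity grows quickly at first, then flattens out
```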
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..pruning import LevelPrunerMasker, SlimPrunerMasker, L1FilterPrunerMasker, \
L2FilterPrunerMasker, FPGMPrunerMasker, TaylorFOWeightFilterPrunerMasker, \
ActivationAPoZRankFilterPrunerMasker, ActivationMeanRankFilterPrunerMasker
MASKER_DICT = {
'level': LevelPrunerMasker,
'slim': SlimPrunerMasker,
'l1': L1FilterPrunerMasker,
'l2': L2FilterPrunerMasker,
'fpgm': FPGMPrunerMasker,
'taylorfo': TaylorFOWeightFilterPrunerMasker,
'apoz': ActivationAPoZRankFilterPrunerMasker,
'mean_activation': ActivationMeanRankFilterPrunerMasker
}
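This dictionary is how `AGP_Pruner` and `OneshotPruner` resolve a masker from the `pruning_algorithm` string. A minimal sketch of that lookup (`model`, `pruner` and `wrapper` are assumed to exist, as they do inside a pruner):

```python
# Resolve a masker class by its key and use it to compute a mask for one layer.
masker = MASKER_DICT['l1'](model, pruner)
masks = masker.calc_mask(sparsity=0.5, wrapper=wrapper, wrapper_idx=0)
```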
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .weight_masker import WeightMasker
__all__ = ['LevelPrunerMasker']
logger = logging.getLogger('torch pruner')
class LevelPrunerMasker(WeightMasker):
"""
Prune to an exact pruning level specification
"""
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
weight = wrapper.module.weight.data.clone()
if wrapper.weight_mask is not None:
# apply base mask for iterative pruning
weight = weight * wrapper.weight_mask
w_abs = weight.abs()
k = int(weight.numel() * sparsity)
if k == 0:
return {'weight_mask': torch.ones(weight.shape).type_as(weight)}
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight_mask': mask_weight}
return mask
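As a standalone illustration of the magnitude thresholding used by `LevelPrunerMasker` above (not part of the NNI source):

```python
import torch

weight = torch.tensor([0.1, -0.5, 0.3, 2.0])
sparsity = 0.5
k = int(weight.numel() * sparsity)                                        # prune 2 of 4 weights
threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()  # 0.3
mask = torch.gt(weight.abs(), threshold).type_as(weight)
print(mask)  # tensor([0., 1., 0., 1.]) -- the two smallest-magnitude weights are masked out
```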
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from ..compressor import Pruner
__all__ = ['TaylorFOWeightFilterPruner']
logger = logging.getLogger('torch gradient rank filter pruners')
class GradientRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers (using gradient values)
to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, optimizer, statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
statistics_batch_num : int
Num of batches for calculating contribution
"""
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
self.set_wrappers_attribute("contribution", None)
self.statistics_batch_num = statistics_batch_num
self.iterations = 0
self.old_step = self.optimizer.step
self.patch_optimizer(self.calc_contributions)
def calc_contributions(self):
raise NotImplementedError('{} calc_contributions is not implemented'.format(self.__class__.__name__))
def get_mask(self, base_mask, contribution, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must be in the range [0, 1)"
assert op_type in config.get('op_types')
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or self.iterations < self.statistics_batch_num:
return mask
mask = self.get_mask(mask, wrapper.contribution, num_prune)
finally:
if self.iterations >= self.statistics_batch_num:
wrapper.if_calculated = True
return mask
class TaylorFOWeightFilterPruner(GradientRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the smallest
importance approximations based on the first order taylor expansion on the weight.
Molchanov, Pavlo and Mallya, Arun and Tyree, Stephen and Frosio, Iuri and Kautz, Jan,
"Importance Estimation for Neural Network Pruning", CVPR 2019.
http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf
"""
def __init__(self, model, config_list, optimizer, statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, optimizer, statistics_batch_num)
def get_mask(self, base_mask, contribution, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest importance approximations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
contribution : torch.Tensor
Layer's importance approximations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
prune_indices = torch.argsort(contribution)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def calc_contributions(self):
"""
Calculate the estimated importance of filters as a sum of individual contribution
based on the first order taylor expansion.
"""
if self.iterations >= self.statistics_batch_num:
return
for wrapper in self.get_modules_wrapper():
filters = wrapper.module.weight.size(0)
contribution = (wrapper.module.weight*wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
if wrapper.contribution is None:
wrapper.contribution = contribution
else:
wrapper.contribution += contribution
self.iterations += 1
@@ -7,299 +7,10 @@ import torch

from schema import And, Optional
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
from .finegrained_pruning import LevelPrunerMasker

__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']

logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner):
"""
Prune to an exact pruning level specification
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer
Parameters
----------
wrapper : Module
the module to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
config = wrapper.config
weight = wrapper.module.weight.data
if not wrapper.if_calculated:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight_mask': mask_weight}
wrapper.if_calculated = True
return mask
else:
return None
class AGP_Pruner(Pruner):
"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self.now_epoch = 0
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
'final_sparsity': And(float, lambda n: 0 <= n <= 1),
'start_epoch': And(int, lambda n: n >= 0),
'end_epoch': And(int, lambda n: n >= 0),
'frequency': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
config = wrapper.config
weight = wrapper.module.weight.data
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if wrapper.if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
mask = {'weight_mask': wrapper.weight_mask}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask['weight_mask']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight_mask': torch.gt(w_abs, threshold).type_as(weight)}
wrapper.if_calculated = True
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
class SlimPruner(Pruner):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.get_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
config = wrapper.config
weight = wrapper.module.weight.data
op_type = wrapper.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if wrapper.if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight_mask': base_mask.detach(), 'bias_mask': base_mask.clone().detach()}
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
wrapper.if_calculated = True
return mask
class LotteryTicketPruner(Pruner):
    """
    This is a Pytorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",

@@ -343,6 +54,7 @@ class LotteryTicketPruner(Pruner):

        super().__init__(model, config_list, optimizer)
        self.curr_prune_iteration = None
        self.prune_iterations = config_list[0]['prune_iterations']
        self.masker = LevelPrunerMasker(model, self)

    def validate_config(self, model, config_list):
        """

@@ -370,16 +82,14 @@ class LotteryTicketPruner(Pruner):

        curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
        return max(1 - curr_keep_ratio, 0)

    def _calc_mask(self, wrapper, sparsity):
        weight = wrapper.module.weight.data
        if self.curr_prune_iteration == 0:
            mask = {'weight_mask': torch.ones(weight.shape).type_as(weight)}
        else:
            curr_sparsity = self._calc_sparsity(sparsity)
            mask = self.masker.calc_mask(sparsity=curr_sparsity, wrapper=wrapper)
        return mask

    def calc_mask(self, wrapper, **kwargs):
        """

@@ -433,7 +143,7 @@ class LotteryTicketPruner(Pruner):

        assert module_wrapper is not None
        sparsity = config.get('sparsity')
        mask = self._calc_mask(module_wrapper, sparsity)
        # TODO: directly using weight_mask is not good
        module_wrapper.weight_mask = mask['weight_mask']
        # there is no mask for bias
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from schema import And, Optional
from .constants import MASKER_DICT
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner', \
'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
logger = logging.getLogger('torch pruner')
class OneshotPruner(Pruner):
"""
Prune model to an exact pruning level for one time.
"""
def __init__(self, model, config_list, pruning_algorithm='level', optimizer=None, **algo_kwargs):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
pruning_algorithm: str
algorithms being used to prune model
optimizer: torch.optim.Optimizer
Optimizer used to train model
algo_kwargs: dict
Additional parameters passed to pruning algorithm masker class
"""
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer
Parameters
----------
wrapper : Module
the module to instrument the compression operation
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict
dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
"""
if wrapper.if_calculated:
return None
sparsity = wrapper.config['sparsity']
if not wrapper.if_calculated:
masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
# masker.calc_mask returning None means the mask was not calculated successfully; it can be tried again later
if masks is not None:
wrapper.if_calculated = True
return masks
else:
return None
class LevelPruner(OneshotPruner):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer)
class SlimPruner(OneshotPruner):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer)
def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
class _StructuredFilterPruner(OneshotPruner):
def __init__(self, model, config_list, pruning_algorithm, optimizer=None, **algo_kwargs):
super().__init__(model, config_list, pruning_algorithm=pruning_algorithm, optimizer=optimizer, **algo_kwargs)
def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
class L1FilterPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer)
class L2FilterPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer)
class FPGMPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='fpgm', optimizer=optimizer)
class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num)
class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \
activation=activation, statistics_batch_num=statistics_batch_num)
class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \
activation=activation, statistics_batch_num=statistics_batch_num)
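A usage sketch for one of these one-shot pruners (`model` is assumed to be a `torch.nn.Module` defined elsewhere; the import path follows the example in the design doc above):

```python
from nni.compression.torch import L1FilterPruner

# prune 60% of the filters of every Conv2d layer in one shot
config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(model, config_list)
model = pruner.compress()
```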
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from .weight_masker import WeightMasker
__all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker', \
'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker', \
'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker']
logger = logging.getLogger('torch filter pruners')
class StructuredWeightMasker(WeightMasker):
"""
A structured pruning masker base class that prunes convolutional layer filters.
"""
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
Parameters
----------
sparsity: float
pruning ratio, preserved weight ratio is `1 - sparsity`
wrapper: PrunerModuleWrapper
layer wrapper of this layer
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict
dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
"""
msg = 'module type {} is not supported!'.format(wrapper.type)
assert wrapper.type == 'Conv2d', msg
weight = wrapper.module.weight.data
bias = None
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
bias = wrapper.module.bias.data
if wrapper.weight_mask is None:
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
else:
mask_weight = wrapper.weight_mask.clone()
if bias is not None:
if wrapper.bias_mask is None:
mask_bias = torch.ones(bias.size()).type_as(bias).detach()
else:
mask_bias = wrapper.bias_mask.clone()
else:
mask_bias = None
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
filters = weight.size(0)
num_prune = int(filters * sparsity)
if filters < 2 or num_prune < 1:
return mask
# weight*mask_weight: apply base mask for iterative pruning
return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx)
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
"""
Calculate the mask of given layer.
Parameters
----------
base_mask: dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight: tensor
the module weight to be pruned
num_prune: int
Num of filters to prune
wrapper: PrunerModuleWrapper
layer wrapper of this layer
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict
dictionary for storing masks
"""
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
class L1FilterPrunerMasker(StructuredWeightMasker):
"""
A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity.
Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf,
"PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR
https://arxiv.org/abs/1608.08710
"""
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
filters = weight.shape[0]
w_abs = weight.abs()
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class L2FilterPrunerMasker(StructuredWeightMasker):
"""
A structured pruning algorithm that prunes the filters with the
smallest L2 norm of the weights.
"""
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
filters = weight.shape[0]
w = weight.view(filters, -1)
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class FPGMPrunerMasker(StructuredWeightMasker):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
"""
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (by out_idx) and
all other filters.
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
"""
A structured pruning algorithm that prunes the filters with the smallest
importance approximations based on the first order taylor expansion on the weight.
Molchanov, Pavlo and Mallya, Arun and Tyree, Stephen and Frosio, Iuri and Kautz, Jan,
"Importance Estimation for Neural Network Pruning", CVPR 2019.
http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf
"""
def __init__(self, model, pruner, statistics_batch_num=1):
super().__init__(model, pruner)
self.pruner.statistics_batch_num = statistics_batch_num
self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.iterations = 0
self.pruner.patch_optimizer(self.calc_contributions)
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
if self.pruner.iterations < self.pruner.statistics_batch_num:
return None
if wrapper.contribution is None:
return None
prune_indices = torch.argsort(wrapper.contribution)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def calc_contributions(self):
"""
Calculate the estimated importance of filters as a sum of individual contribution
based on the first order taylor expansion.
"""
if self.pruner.iterations >= self.pruner.statistics_batch_num:
return
for wrapper in self.pruner.get_modules_wrapper():
filters = wrapper.module.weight.size(0)
contribution = (wrapper.module.weight*wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
if wrapper.contribution is None:
wrapper.contribution = contribution
else:
wrapper.contribution += contribution
self.pruner.iterations += 1
class ActivationFilterPrunerMasker(StructuredWeightMasker):
def __init__(self, model, pruner, statistics_batch_num=1, activation='relu'):
super().__init__(model, pruner)
self.statistics_batch_num = statistics_batch_num
self.pruner.hook_id = self._add_activation_collector(self.pruner)
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.pruner.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.pruner.activation = torch.nn.functional.relu6
else:
self.pruner.activation = None
def _add_activation_collector(self, pruner):
def collector(collected_activation):
def hook(module_, input_, output):
collected_activation.append(pruner.activation(output.detach().cpu()))
return hook
pruner.collected_activation = {}
pruner._fwd_hook_id += 1
pruner._fwd_hook_handles[pruner._fwd_hook_id] = []
for wrapper_idx, wrapper in enumerate(pruner.get_modules_wrapper()):
pruner.collected_activation[wrapper_idx] = []
handle = wrapper.register_forward_hook(collector(pruner.collected_activation[wrapper_idx]))
pruner._fwd_hook_handles[pruner._fwd_hook_id].append(handle)
return pruner._fwd_hook_id
class ActivationAPoZRankFilterPrunerMasker(ActivationFilterPrunerMasker):
"""
A structured pruning algorithm that prunes the filters with the
smallest APoZ(average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
assert wrapper_idx is not None
activations = self.pruner.collected_activation[wrapper_idx]
if len(activations) < self.statistics_batch_num:
return None
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
if len(activations) >= self.statistics_batch_num and self.pruner.hook_id in self.pruner._fwd_hook_handles:
    self.pruner.remove_activation_collector(self.pruner.hook_id)
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPrunerMasker(ActivationFilterPrunerMasker):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
assert wrapper_idx is not None
activations = self.pruner.collected_activation[wrapper_idx]
if len(activations) < self.statistics_batch_num:
return None
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
if len(activations) >= self.statistics_batch_num and self.pruner.hook_id in self.pruner._fwd_hook_handles:
self.pruner.remove_activation_collector(self.pruner.hook_id)
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
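# Average over the batch and spatial dimensions, giving one mean activation value per filter.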
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
class SlimPrunerMasker(WeightMasker):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, pruner, **kwargs):
super().__init__(model, pruner)
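# Gather the absolute BN scaling factors (gamma) of every layer to be pruned and derive a
# single network-wide threshold: the k-th smallest value, where k = sparsity * total channels.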
weight_list = []
for (layer, _) in pruner.get_modules_to_compress():
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight = wrapper.module.weight.data.clone()
if wrapper.weight_mask is not None:
# apply base mask for iterative pruning
weight = weight * wrapper.weight_mask
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight_mask': base_mask.detach(), 'bias_mask': base_mask.clone().detach()}
filters = weight.size(0)
num_prune = int(filters * sparsity)
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
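# Keep only channels whose scaling factor |gamma| is strictly above the global threshold;
# weight and bias masks share the same channel pattern.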
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
return mask
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
class WeightMasker(object):
def __init__(self, model, pruner, **kwargs):
self.model = model
self.pruner = pruner
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
Parameters
----------
sparsity: float
pruning ratio, preserved weight ratio is `1 - sparsity`
wrapper: PrunerModuleWrapper
layer wrapper of this layer
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict
dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
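Example
-------
For instance, a masker that prunes nothing might return
{'weight_mask': torch.ones_like(wrapper.module.weight), 'bias_mask': None}.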
"""
raise NotImplementedError('{} calc_mask is not implemented'.format(self.__class__.__name__))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import And, Optional
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger('torch weight rank filter pruners')
class WeightRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the least important filters in
convolution layers, per a pruner-specific importance criterion, to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
Parameters
----------
wrapper : Module
the module to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
assert op_type in config.get('op_types')
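# One-shot behaviour: once a mask has been computed for this wrapper, later calls return
# None until `if_calculated` is reset (FPGMPruner.update_epoch does this every epoch).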
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
wrapper.if_calculated = True
return mask
class L1FilterPruner(WeightRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the smallest sum of
absolute kernel weights in the convolution layers to achieve a preset level of network sparsity.
Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf,
"PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR
https://arxiv.org/abs/1608.08710
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of their absolute kernel weights are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape as weight or bias; all items in the basic mask are 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
filters = weight.shape[0]
w_abs = weight.abs()
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
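# The threshold is the largest L1 norm among the num_prune smallest filters; torch.gt keeps
# only filters strictly above it, so those num_prune filters are masked (assuming no ties).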
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class L2FilterPruner(WeightRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest L2 norm of the weights.
"""
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest L2 norm of the kernel weights are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape as weight or bias; all items in the basic mask are 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
filters = weight.shape[0]
w = weight.view(filters, -1)
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
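# Per-filter L2 norm over the flattened kernel; ranking and thresholding mirror the L1 variant above.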
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class FPGMPruner(WeightRankFilterPruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
"""
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass optimizer of the model to it"
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters closest to the geometric median of all filters in the layer, i.e. with the smallest total distance to the other filters, are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape as weight and bias; all items in the basic mask are 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
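# Filters with the smallest total distance to all other filters lie closest to the
# geometric median and are treated as the most redundant, so they are selected for pruning.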
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_distance_sum(self, weight, out_idx):
"""
Calculate the total distance between a specified filter (selected by out_idx) and
all other filters.
Optimized version of the following naive implementation:
def _get_distance_sum(self, weight, out_idx):
    w = weight.view(weight.size(0), -1)
    dist_sum = 0.
    for k in w:
        dist_sum += torch.dist(k, w[out_idx], p=2)
    return dist_sum
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
Returns
-------
float32
The total distance
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch):
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
@@ -59,11 +59,6 @@ def tf2(func):
return test_tf2_func
-# for fpgm filter pruner test
-w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
class CompressorTestCase(TestCase):
def test_torch_quantizer_modules_detection(self):
# test if modules can be detected
@@ -125,11 +120,12 @@ class CompressorTestCase(TestCase):
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
-`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))`
+`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`
If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
-`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))`
+`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
"""
+w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
model = TorchModel()
config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
@@ -137,16 +133,17 @@ class CompressorTestCase(TestCase):
model.conv2.module.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(model.conv2)
-assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
+assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))
model.conv2.module.weight.data = torch.tensor(w).float()
model.conv2.if_calculated = False
model.conv2.config = config_list[0]
masks = pruner.calc_mask(model.conv2)
-assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
+assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))
@tf2
def test_tf_fpgm_pruner(self):
+w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
model = get_tf_model()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}]
@@ -167,25 +164,26 @@ class CompressorTestCase(TestCase):
PRUNING FILTERS FOR EFFICIENT CONVNETS,
https://arxiv.org/abs/1608.08710
-So if sparsity is 0.2, the expected masks should mask out filter 0, this can be verified through:
-`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
-If sparsity is 0.6, the expected masks should mask out filter 0,1,2, this can be verified through:
-`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
+So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0, this can be verified through:
+`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`
+If sparsity is 0.6 for conv2, the expected masks should mask out filter 0,1,2, this can be verified through:
+`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
"""
-w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
-np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
+w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
+w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
{'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)
-model.conv1.module.weight.data = torch.tensor(w).float()
-model.conv2.module.weight.data = torch.tensor(w).float()
+model.conv1.module.weight.data = torch.tensor(w1).float()
+model.conv2.module.weight.data = torch.tensor(w2).float()
mask1 = pruner.calc_mask(model.conv1)
mask2 = pruner.calc_mask(model.conv2)
-assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
-assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
+assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
+assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))
def test_torch_slim_pruner(self):
"""
...
@@ -8,7 +8,8 @@ import torch.nn.functional as F
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
-L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
+L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
+TaylorFOWeightFilterPruner
def validate_sparsity(wrapper, sparsity, bias=False):
masks = [wrapper.weight_mask]
@@ -39,7 +40,7 @@ prune_config = {
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
-'op_types': ['default']
+'op_types': ['Conv2d']
}],
'validators': []
},
@@ -83,6 +84,16 @@ prune_config = {
lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
]
},
+'taylorfo': {
+'pruner_class': TaylorFOWeightFilterPruner,
+'config_list': [{
+'sparsity': 0.5,
+'op_types': ['Conv2d'],
+}],
+'validators': [
+lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
+]
+},
'mean_activation': {
'pruner_class': ActivationMeanRankFilterPruner,
'config_list': [{
@@ -116,9 +127,8 @@ class Model(nn.Module):
def forward(self, x):
return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
-def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'mean_activation', 'apoz'], bias=True):
+def pruners_test(pruner_names=['agp', 'level', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz'], bias=True):
for pruner_name in pruner_names:
-print('testing {}...'.format(pruner_name))
model = Model(bias=bias)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = prune_config[pruner_name]['config_list']
@@ -142,6 +152,11 @@
loss.backward()
optimizer.step()
+if pruner_name == 'taylorfo':
+# taylorfo algorithm calculate contributions at first iteration(step), and do pruning
+# when iteration >= statistics_batch_num (default 1)
+optimizer.step()
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28))
for v in prune_config[pruner_name]['validators']:
@@ -151,6 +166,30 @@
os.remove('./mask_tmp.pth')
os.remove('./onnx_tmp.pth')
+def test_agp(pruning_algorithm):
+model = Model()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+config_list = prune_config['agp']['config_list']
+pruner = AGP_Pruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm)
+pruner.compress()
+x = torch.randn(2, 1, 28, 28)
+y = torch.tensor([0, 1]).long()
+for epoch in range(config_list[0]['start_epoch'], config_list[0]['end_epoch']+1):
+pruner.update_epoch(epoch)
+out = model(x)
+loss = F.cross_entropy(out, y)
+optimizer.zero_grad()
+loss.backward()
+optimizer.step()
+target_sparsity = pruner.compute_target_sparsity(config_list[0])
+actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
+# set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
+assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
class PrunerTestCase(TestCase):
def test_pruners(self):
pruners_test(bias=True)
@@ -158,5 +197,13 @@ class PrunerTestCase(TestCase):
def test_pruners_no_bias(self):
pruners_test(bias=False)
+def test_agp_pruner(self):
+for pruning_algorithm in ['l1', 'l2', 'taylorfo', 'apoz']:
+test_agp(pruning_algorithm)
+for pruning_algorithm in ['level']:
+prune_config['agp']['config_list'][0]['op_types'] = ['default']
+test_agp(pruning_algorithm)
if __name__ == '__main__':
main()