Commit b0c0eb7b authored by Tang Lang, committed by QuanluZhang

Add network trimming pruning algorithm and fix bias mask(testing) (#1867)

parent 80a49a10
ActivationRankFilterPruner on NNI Compressor
===
## 1. Introduction
ActivationRankFilterPruner is a series of pruners which prune filters according to some importance criterion calculated from the filters' output activations.
| Pruner | Importance criterion | Reference paper |
| :----------------------------: | :-------------------------------: | :----------------------------------------------------------: |
| ActivationAPoZRankFilterPruner | APoZ(average percentage of zeros) | [Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures](https://arxiv.org/abs/1607.03250) |
| ActivationMeanRankFilterPruner | mean value of output activations | [Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440) |
## 2. Pruners
### ActivationAPoZRankFilterPruner
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"[Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures](https://arxiv.org/abs/1607.03250)", ICLR 2016.
ActivationAPoZRankFilterPruner prunes the filters with the largest APoZ (average percentage of zeros) in their output activations, i.e., the filters whose outputs are most often zero.
The APoZ is defined as:
![](../../img/apoz.png)
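For reference, the figure above corresponds to the following definition from the paper (transcribed here; notation follows Hu et al.): the APoZ of the c-th neuron in the i-th layer is

```
APoZ_c^{(i)} = APoZ(O_c^{(i)}) = \frac{\sum_{k=1}^{N} \sum_{j=1}^{M} f\big(O_{c,j}^{(i)}(k) = 0\big)}{N \times M}
```

where f(true) = 1 and f(false) = 0, M is the dimension of the output feature map O_c^{(i)}, and N is the number of examples used for the statistics.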
### ActivationMeanRankFilterPruner
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"[Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440)", ICLR 2017.
ActivationMeanRankFilterPruner prunes the filters with the smallest mean value of output activations.
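In this commit's implementation, the mean importance of filter c is simply the average of its output activation over the collected batches and spatial positions (a restatement of the code, not a formula from the paper):

```
\text{mean}_c = \frac{1}{N \cdot H \cdot W} \sum_{k=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} O_c(k, h, w)
```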
## 3. Usage
PyTorch code
```python
from nni.compression.torch import ActivationAPoZRankFilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'], 'op_names': ['conv1', 'conv2'] }]
pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1)
pruner.compress()
```
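Because this pruner ranks filters by their output activations, the hooks registered by `compress()` must observe `statistics_batch_num` batches before the real masks are calculated. A minimal sketch of that step, assuming an existing `train_loader` and `device` (the names are illustrative, not part of the API):

```python
import torch

# Run forward passes so the registered hooks can collect activation statistics;
# the masks are computed once `statistics_batch_num` batches have been observed.
statistics_batch_num = 1  # matches the value passed to the pruner above
model.eval()
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(train_loader):
        model(data.to(device))
        if batch_idx + 1 >= statistics_batch_num:
            break
```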
#### User configuration for ActivationAPoZRankFilterPruner
- **sparsity:** The percentage of convolutional filters to be pruned.
- **op_types:** Only Conv2d is supported in ActivationAPoZRankFilterPruner
## 4. Experiment
TODO.
...@@ -14,10 +14,14 @@ We have provided several compression algorithms, including several pruning and q ...@@ -14,10 +14,14 @@ We have provided several compression algorithms, including several pruning and q
|---|---| |---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights | | [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)| | [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [L1Filter Pruner](./Pruner.md#l1filter-pruner) | Pruning least important filters in convolution layers(PRUNING FILTERS FOR EFFICIENT CONVNETS)[Reference Paper](https://arxiv.org/abs/1608.08710) |
| [Slim Pruner](./Pruner.md#slim-pruner) | Pruning channels in convolution layers by pruning scaling factors in BN layers(Learning Efficient Convolutional Networks through Network Slimming)[Reference Paper](https://arxiv.org/abs/1708.06519) |
| [Lottery Ticket Pruner](./Pruner.md#agp-pruner) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)| | [Lottery Ticket Pruner](./Pruner.md#agp-pruner) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)|
| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)| | [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)|
| [L1Filter Pruner](./Pruner.md#l1filter-pruner) | Pruning filters with the smallest L1 norm of weights in convolution layers (PRUNING FILTERS FOR EFFICIENT CONVNETS) [Reference Paper](https://arxiv.org/abs/1608.08710) |
| [L2Filter Pruner](./Pruner.md#l2filter-pruner) | Pruning filters with the smallest L2 norm of weights in convolution layers |
| [ActivationAPoZRankFilterPruner](./Pruner.md#ActivationAPoZRankFilterPruner) | Pruning filters with the largest APoZ (average percentage of zeros) of output activations (Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures) [Reference Paper](https://arxiv.org/abs/1607.03250) |
| [ActivationMeanRankFilterPruner](./Pruner.md#ActivationMeanRankFilterPruner) | Pruning filters with the smallest mean value of output activations (Pruning Convolutional Neural Networks for Resource Efficient Inference) [Reference Paper](https://arxiv.org/abs/1611.06440) |
| [Slim Pruner](./Pruner.md#slim-pruner) | Pruning channels in convolution layers by pruning scaling factors in BN layers (Learning Efficient Convolutional Networks through Network Slimming) [Reference Paper](https://arxiv.org/abs/1708.06519) |
**Quantization** **Quantization**
......
...@@ -10,7 +10,7 @@ We first sort the weights in the specified layer by their absolute values. And t ...@@ -10,7 +10,7 @@ We first sort the weights in the specified layer by their absolute values. And t
### Usage ### Usage
Tensorflow code Tensorflow code
``` ```python
from nni.compression.tensorflow import LevelPruner from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model_graph, config_list) pruner = LevelPruner(model_graph, config_list)
...@@ -18,7 +18,7 @@ pruner.compress() ...@@ -18,7 +18,7 @@ pruner.compress()
``` ```
PyTorch code PyTorch code
``` ```python
from nni.compression.torch import LevelPruner from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model, config_list) pruner = LevelPruner(model, config_list)
...@@ -40,8 +40,6 @@ This is an iterative pruner, In [To prune, or not to prune: exploring the effica ...@@ -40,8 +40,6 @@ This is an iterative pruner, In [To prune, or not to prune: exploring the effica
### Usage ### Usage
You can prune all weight from 0% to 80% sparsity in 10 epoch with the code below. You can prune all weight from 0% to 80% sparsity in 10 epoch with the code below.
First, you should import pruner and add mask to model.
Tensorflow code Tensorflow code
```python ```python
from nni.compression.tensorflow import AGP_Pruner from nni.compression.tensorflow import AGP_Pruner
...@@ -71,7 +69,7 @@ pruner = AGP_Pruner(model, config_list) ...@@ -71,7 +69,7 @@ pruner = AGP_Pruner(model, config_list)
pruner.compress() pruner.compress()
``` ```
Second, you should add code below to update epoch number when you finish one epoch in your training code. you should add code below to update epoch number when you finish one epoch in your training code.
Tensorflow code Tensorflow code
```python ```python
...@@ -133,13 +131,16 @@ The above configuration means that there are 5 times of iterative pruning. As th ...@@ -133,13 +131,16 @@ The above configuration means that there are 5 times of iterative pruning. As th
* **sparsity:** The final sparsity when the compression is done. * **sparsity:** The final sparsity when the compression is done.
*** ***
## FPGM Pruner ## WeightRankFilterPruner
WeightRankFilterPruner is a series of pruners that prune the filters with the smallest importance criterion, calculated from the weights of convolution layers, to achieve a preset level of network sparsity.
### 1, FPGM Pruner
This is an one-shot pruner, FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) This is an one-shot pruner, FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf)
>Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. >Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance.
### Usage #### Usage
First, you should import the pruner and add a mask to the model.
Tensorflow code Tensorflow code
```python ```python
...@@ -163,7 +164,7 @@ pruner.compress() ...@@ -163,7 +164,7 @@ pruner.compress()
``` ```
Note: FPGM Pruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers. Note: FPGM Pruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers.
Second, you should add code below to update epoch number at beginning of each epoch. you should add code below to update epoch number at beginning of each epoch.
Tensorflow code Tensorflow code
```python ```python
...@@ -180,7 +181,7 @@ You can view example for more information ...@@ -180,7 +181,7 @@ You can view example for more information
*** ***
## L1Filter Pruner ### 2, L1Filter Pruner
This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf. This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf.
...@@ -193,12 +194,16 @@ This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https: ...@@ -193,12 +194,16 @@ This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https:
> 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|) > 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|)
> 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j). > 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j).
> 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The > 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The
> kernels in the next convolutional layer corresponding to the pruned feature maps are also > kernels in the next convolutional layer corresponding to the pruned feature maps are also
> removed. > removed.
> 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel > 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel
> weights are copied to the new model. > weights are copied to the new model.
``` #### Usage
PyTorch code
```python
from nni.compression.torch import L1FilterPruner from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }] config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L1FilterPruner(model, config_list) pruner = L1FilterPruner(model, config_list)
...@@ -208,7 +213,90 @@ pruner.compress() ...@@ -208,7 +213,90 @@ pruner.compress()
#### User configuration for L1Filter Pruner #### User configuration for L1Filter Pruner
- **sparsity:** This is to specify the sparsity operations to be compressed to - **sparsity:** This is to specify the sparsity operations to be compressed to
- **op_types:** Only Conv2d is supported in L1Filter Pruner - **op_types:** Only Conv1d and Conv2d is supported in L1Filter Pruner
***
### 3, L2Filter Pruner
This is a structured pruning algorithm that prunes the filters with the smallest L2 norm of the weights.
#### Usage
PyTorch code
```python
from nni.compression.torch import L2FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L2FilterPruner(model, config_list)
pruner.compress()
```
#### User configuration for L2Filter Pruner
- **sparsity:** The percentage of convolutional filters to be pruned.
- **op_types:** Only Conv1d and Conv2d are supported in L2Filter Pruner
## ActivationRankFilterPruner
ActivationRankFilterPruner is a series of pruners that prune the filters with the smallest importance criterion, calculated from the output activations of convolution layers, to achieve a preset level of network sparsity.
### 1, ActivationAPoZRankFilterPruner
This is a one-shot pruner. ActivationAPoZRankFilterPruner is an implementation of the paper [Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures](https://arxiv.org/abs/1607.03250).
#### Usage
PyTorch code
```python
from nni.compression.torch import ActivationAPoZRankFilterPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1)
pruner.compress()
```
Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers.
You can view the example for more information.
#### User configuration for ActivationAPoZRankFilterPruner
- **sparsity:** The percentage of convolutional filters to be pruned.
- **op_types:** Only Conv2d is supported in ActivationAPoZRankFilterPruner
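Besides `config_list`, the constructor also accepts `activation` (`'relu'` or `'relu6'`) and `statistics_batch_num` (the number of batches used for activation statistics), as defined in the implementation added by this commit:

```python
pruner = ActivationAPoZRankFilterPruner(model, config_list, activation='relu', statistics_batch_num=1)
```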
***
### 2, ActivationMeanRankFilterPruner
This is a one-shot pruner. ActivationMeanRankFilterPruner is an implementation of the paper [Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440).
#### Usage
PyTorch code
```python
from nni.compression.torch import ActivationMeanRankFilterPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = ActivationMeanRankFilterPruner(model, config_list)
pruner.compress()
```
Note: ActivationMeanRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers.
You can view the example for more information.
#### User configuration for ActivationMeanRankFilterPruner
- **sparsity:** The percentage of convolutional filters to be pruned.
- **op_types:** Only Conv2d is supported in ActivationMeanRankFilterPruner
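ActivationMeanRankFilterPruner accepts the same `activation` and `statistics_batch_num` constructor arguments:

```python
pruner = ActivationMeanRankFilterPruner(model, config_list, activation='relu', statistics_batch_num=1)
```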
***
## Slim Pruner ## Slim Pruner
...@@ -222,7 +310,7 @@ This is an one-shot pruner, In ['Learning Efficient Convolutional Networks throu ...@@ -222,7 +310,7 @@ This is an one-shot pruner, In ['Learning Efficient Convolutional Networks throu
PyTorch code PyTorch code
``` ```python
from nni.compression.torch import SlimPruner from nni.compression.torch import SlimPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }] config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
pruner = SlimPruner(model, config_list) pruner = SlimPruner(model, config_list)
......
L1FilterPruner on NNI Compressor WeightRankFilterPruner on NNI Compressor
=== ===
## 1. Introduction ## 1. Introduction
WeightRankFilterPruner is a series of pruners which prune filters according to some importance criterion calculated from the filters' weights.
| Pruner | Importance criterion | Reference paper |
| :------------: | :-------------------------: | :----------------------------------------------------------: |
| L1FilterPruner | L1 norm of weights | [PRUNING FILTERS FOR EFFICIENT CONVNETS](https://arxiv.org/abs/1608.08710) |
| L2FilterPruner | L2 norm of weights | |
| FPGMPruner | Geometric Median of weights | [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) |
## 2. Pruners
### L1FilterPruner
L1FilterPruner is a general structured pruning algorithm for pruning filters in the convolutional layers. L1FilterPruner is a general structured pruning algorithm for pruning filters in the convolutional layers.
In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf. In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf.
...@@ -16,12 +28,26 @@ In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), ...@@ -16,12 +28,26 @@ In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710),
> 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|) > 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|)
> 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j). > 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j).
> 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The > 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The
> kernels in the next convolutional layer corresponding to the pruned feature maps are also > kernels in the next convolutional layer corresponding to the pruned feature maps are also
> removed. > removed.
> 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel > 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel
> weights are copied to the new model. > weights are copied to the new model.
### L2FilterPruner
L2FilterPruner is similar to L1FilterPruner; it simply replaces the importance criterion with the L2 norm of the weights instead of the L1 norm.
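In the notation of the L1FilterPruner procedure above, the ranking score of filter j becomes (a restatement under the same notation):

```
s_j = \sqrt{\sum_{l=1}^{n_i} \sum K_l^2}
```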
### FPGMPruner
Yang He, Ping Liu, Ziwei Wang, Zhilan Hu, Yi Yang
"[Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/abs/1811.00250)", CVPR 2019.
FPGMPruner prunes the filters that are closest to the geometric median of the filters in the same layer, i.e., the most redundant ones.
![](../../img/fpgm_fig1.png)
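The geometric-median criterion can be illustrated with a small sketch (illustrative only; this is not the pruner's internal API):

```python
import torch

def fpgm_prune_indices(weight, num_prune):
    # For each filter, sum its Euclidean distances to all other filters in the
    # same layer; the filters with the smallest distance sums lie closest to the
    # geometric median, i.e. they are the most redundant, and are selected for pruning.
    flat = weight.view(weight.size(0), -1)   # (out_channels, in_channels * kH * kW)
    dist = torch.cdist(flat, flat, p=2)      # pairwise L2 distances between filters
    scores = dist.sum(dim=1)                 # distance sum per filter
    return torch.argsort(scores)[:num_prune]
```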
## 2. Usage ## 3. Usage
PyTorch code PyTorch code
...@@ -37,9 +63,9 @@ pruner.compress() ...@@ -37,9 +63,9 @@ pruner.compress()
- **sparsity:** This is to specify the sparsity operations to be compressed to - **sparsity:** This is to specify the sparsity operations to be compressed to
- **op_types:** Only Conv2d is supported in L1Filter Pruner - **op_types:** Only Conv2d is supported in L1Filter Pruner
## 3. Experiment ## 4. Experiment
We implemented one of the experiments in ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), we pruned **VGG-16** for CIFAR-10 to **VGG-16-pruned-A** in the paper, in which $64\%$ parameters are pruned. Our experiments results are as follows: We implemented one of the experiments in ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710) with **L1FilterPruner**, we pruned **VGG-16** for CIFAR-10 to **VGG-16-pruned-A** in the paper, in which $64\%$ parameters are pruned. Our experiments results are as follows:
| Model | Error(paper/ours) | Parameters | Pruned | | Model | Error(paper/ours) | Parameters | Pruned |
| --------------- | ----------------- | --------------- | -------- | | --------------- | ----------------- | --------------- | -------- |
......
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import ActivationAPoZRankFilterPruner
from models.cifar10.vgg import VGG
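# Example: train VGG-16 on CIFAR-10, prune selected conv layers with
# ActivationAPoZRankFilterPruner at 50% sparsity, fine-tune, and export the pruned model.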
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, acc))
return acc
def main():
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
test(model, device, test_loader)
# top1 = 93.51%
# Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
configure_list = [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = ActivationAPoZRankFilterPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 88.19%
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
# mask_path stores mask_dict of the pruned model
pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
# top1 = 93.53%
if __name__ == '__main__':
main()
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner
from models.cifar10.vgg import VGG
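# Example: train VGG-16 on CIFAR-10, prune selected conv layers with L1FilterPruner
# (the 'VGG-16-pruned-A' setting), fine-tune, and export the pruned model.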
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, acc))
return acc
def main():
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
test(model, device, test_loader)
# top1 = 93.51%
# Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
configure_list = [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 88.19%
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
# mask_path stores mask_dict of the pruned model
pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
# top1 = 93.53%
if __name__ == '__main__':
main()
...@@ -5,7 +5,8 @@ import logging ...@@ -5,7 +5,8 @@ import logging
import torch import torch
from .compressor import Pruner from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner'] __all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
logger = logging.getLogger('torch pruner') logger = logging.getLogger('torch pruner')
...@@ -26,7 +27,7 @@ class LevelPruner(Pruner): ...@@ -26,7 +27,7 @@ class LevelPruner(Pruner):
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.if_init_list = {} self.mask_calculated_ops = set()
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
""" """
...@@ -39,22 +40,24 @@ class LevelPruner(Pruner): ...@@ -39,22 +40,24 @@ class LevelPruner(Pruner):
layer's pruning config layer's pruning config
Returns Returns
------- -------
torch.Tensor dict
mask of the layer's weight dictionary for storing masks
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
op_name = layer.name op_name = layer.name
if self.if_init_list.get(op_name, True): if op_name not in self.mask_calculated_ops:
w_abs = weight.abs() w_abs = weight.abs()
k = int(weight.numel() * config['sparsity']) k = int(weight.numel() * config['sparsity'])
if k == 0: if k == 0:
return torch.ones(weight.shape).type_as(weight) return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max() threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs, threshold).type_as(weight) mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask}) self.mask_dict.update({op_name: mask})
self.if_init_list.update({op_name: False}) self.mask_calculated_ops.add(op_name)
else: else:
assert op_name in self.mask_dict, "op_name not in the mask_dict"
mask = self.mask_dict[op_name] mask = self.mask_dict[op_name]
return mask return mask
...@@ -94,8 +97,8 @@ class AGP_Pruner(Pruner): ...@@ -94,8 +97,8 @@ class AGP_Pruner(Pruner):
layer's pruning config layer's pruning config
Returns Returns
------- -------
torch.Tensor dict
mask of the layer's weight dictionary for storing masks
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
...@@ -104,7 +107,7 @@ class AGP_Pruner(Pruner): ...@@ -104,7 +107,7 @@ class AGP_Pruner(Pruner):
freq = config.get('frequency', 1) freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \ if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
and (self.now_epoch - start_epoch) % freq == 0: and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight)) mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
target_sparsity = self.compute_target_sparsity(config) target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity) k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
...@@ -112,11 +115,11 @@ class AGP_Pruner(Pruner): ...@@ -112,11 +115,11 @@ class AGP_Pruner(Pruner):
# if we want to generate new mask, we should update weigth first # if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max() threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).type_as(weight) new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask}) self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False}) self.if_init_list.update({op_name: False})
else: else:
new_mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight)) new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
return new_mask return new_mask
def compute_target_sparsity(self, config): def compute_target_sparsity(self, config):
...@@ -208,8 +211,8 @@ class SlimPruner(Pruner): ...@@ -208,8 +211,8 @@ class SlimPruner(Pruner):
layer's pruning config layer's pruning config
Returns Returns
------- -------
torch.Tensor dict
mask of the layer's weight dictionary for storing masks
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
...@@ -219,10 +222,17 @@ class SlimPruner(Pruner): ...@@ -219,10 +222,17 @@ class SlimPruner(Pruner):
if op_name in self.mask_calculated_ops: if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict assert op_name in self.mask_dict
return self.mask_dict.get(op_name) return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight) base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
try: try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
w_abs = weight.abs() w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight) mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
finally: finally:
self.mask_dict.update({layer.name: mask}) self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name) self.mask_calculated_ops.add(layer.name)
...@@ -230,7 +240,7 @@ class SlimPruner(Pruner): ...@@ -230,7 +240,7 @@ class SlimPruner(Pruner):
return mask return mask
class RankFilterPruner(Pruner): class WeightRankFilterPruner(Pruner):
""" """
A structured pruning base class that prunes the filters with the smallest A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity. importance criterion in convolution layers to achieve a preset level of network sparsity.
...@@ -248,10 +258,10 @@ class RankFilterPruner(Pruner): ...@@ -248,10 +258,10 @@ class RankFilterPruner(Pruner):
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.mask_calculated_ops = set() self.mask_calculated_ops = set() # operations whose mask has been calculated
def _get_mask(self, base_mask, weight, num_prune): def _get_mask(self, base_mask, weight, num_prune):
return torch.ones(weight.size()).type_as(weight) return {'weight': None, 'bias': None}
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
""" """
...@@ -265,20 +275,25 @@ class RankFilterPruner(Pruner): ...@@ -265,20 +275,25 @@ class RankFilterPruner(Pruner):
layer's pruning config layer's pruning config
Returns Returns
------- -------
torch.Tensor dict
mask of the layer's weight dictionary for storing masks
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
op_name = layer.name op_name = layer.name
op_type = layer.type op_type = layer.type
assert 0 <= config.get('sparsity') < 1 assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'] assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
assert op_type in config.get('op_types') assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops: if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict assert op_name in self.mask_dict
return self.mask_dict.get(op_name) return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight) mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
try: try:
filters = weight.size(0) filters = weight.size(0)
num_prune = int(filters * config.get('sparsity')) num_prune = int(filters * config.get('sparsity'))
...@@ -288,10 +303,10 @@ class RankFilterPruner(Pruner): ...@@ -288,10 +303,10 @@ class RankFilterPruner(Pruner):
finally: finally:
self.mask_dict.update({op_name: mask}) self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name) self.mask_calculated_ops.add(op_name)
return mask.detach() return mask
class L1FilterPruner(RankFilterPruner): class L1FilterPruner(WeightRankFilterPruner):
""" """
A structured pruning algorithm that prunes the filters of smallest magnitude A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity. weights sum in the convolution layers to achieve a preset level of network sparsity.
...@@ -319,31 +334,33 @@ class L1FilterPruner(RankFilterPruner): ...@@ -319,31 +334,33 @@ class L1FilterPruner(RankFilterPruner):
Filters with the smallest sum of its absolute kernel weights are masked. Filters with the smallest sum of its absolute kernel weights are masked.
Parameters Parameters
---------- ----------
base_mask : torch.Tensor base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1. The basic mask with the same shape of weight or bias, all item in the basic mask is 1.
weight : torch.Tensor weight : torch.Tensor
Layer's weight Layer's weight
num_prune : int num_prune : int
Num of filters to prune Num of filters to prune
Returns Returns
------- -------
torch.Tensor dict
Mask of the layer's weight dictionary for storing masks
""" """
filters = weight.shape[0] filters = weight.shape[0]
w_abs = weight.abs() w_abs = weight.abs()
w_abs_structured = w_abs.view(filters, -1).sum(dim=1) w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max() threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight) mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
return mask return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
class L2FilterPruner(RankFilterPruner): class L2FilterPruner(WeightRankFilterPruner):
""" """
A structured pruning algorithm that prunes the filters with the A structured pruning algorithm that prunes the filters with the
smallest L2 norm of the absolute kernel weights are masked. smallest L2 norm of the weights.
""" """
def __init__(self, model, config_list): def __init__(self, model, config_list):
...@@ -365,27 +382,28 @@ class L2FilterPruner(RankFilterPruner): ...@@ -365,27 +382,28 @@ class L2FilterPruner(RankFilterPruner):
Filters with the smallest L2 norm of the absolute kernel weights are masked. Filters with the smallest L2 norm of the absolute kernel weights are masked.
Parameters Parameters
---------- ----------
base_mask : torch.Tensor base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1. The basic mask with the same shape of weight or bias, all item in the basic mask is 1.
weight : torch.Tensor weight : torch.Tensor
Layer's weight Layer's weight
num_prune : int num_prune : int
Num of filters to prune Num of filters to prune
Returns Returns
------- -------
torch.Tensor dict
Mask of the layer's weight dictionary for storing masks
""" """
filters = weight.shape[0] filters = weight.shape[0]
w = weight.view(filters, -1) w = weight.view(filters, -1)
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1)) w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max() threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight) mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
return mask return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
class FPGMPruner(RankFilterPruner): class FPGMPruner(WeightRankFilterPruner):
""" """
A filter pruner via geometric median. A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
...@@ -410,20 +428,22 @@ class FPGMPruner(RankFilterPruner): ...@@ -410,20 +428,22 @@ class FPGMPruner(RankFilterPruner):
Filters with the smallest sum of its absolute kernel weights are masked. Filters with the smallest sum of its absolute kernel weights are masked.
Parameters Parameters
---------- ----------
base_mask : torch.Tensor base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1. The basic mask with the same shape of weight and bias, all item in the basic mask is 1.
weight : torch.Tensor weight : torch.Tensor
Layer's weight Layer's weight
num_prune : int num_prune : int
Num of filters to prune Num of filters to prune
Returns Returns
------- -------
torch.Tensor dict
Mask of the layer's weight dictionary for storing masks
""" """
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx: for idx in min_gm_idx:
base_mask[idx] = 0. base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask return base_mask
def _get_min_gm_kernel_idx(self, weight, n): def _get_min_gm_kernel_idx(self, weight, n):
...@@ -471,3 +491,251 @@ class FPGMPruner(RankFilterPruner): ...@@ -471,3 +491,251 @@ class FPGMPruner(RankFilterPruner):
def update_epoch(self, epoch): def update_epoch(self, epoch):
self.mask_calculated_ops = set() self.mask_calculated_ops = set()
class ActivationRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.activation = torch.nn.functional.relu6
else:
self.activation = None
def compress(self):
"""
        Compress the model: instrument the selected layers and register forward hooks that collect
        output activations. The real masks are calculated in calc_mask once statistics_batch_num
        batches of activations have been collected.
"""
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook)
return self.bound_model
def _get_mask(self, base_mask, activations, num_prune):
return {'weight': None, 'bias': None}
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
return mask
mask = self._get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask
class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
    largest APoZ (average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
        Filters with the largest APoZ (average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
        Filters with the smallest mean value of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
...@@ -16,6 +16,7 @@ class LayerInfo: ...@@ -16,6 +16,7 @@ class LayerInfo:
self._forward = None self._forward = None
class Compressor: class Compressor:
""" """
Abstract base PyTorch compressor Abstract base PyTorch compressor
...@@ -193,10 +194,16 @@ class Pruner(Compressor): ...@@ -193,10 +194,16 @@ class Pruner(Compressor):
layer._forward = layer.module.forward layer._forward = layer.module.forward
def new_forward(*inputs): def new_forward(*inputs):
mask = self.calc_mask(layer, config)
# apply mask to weight # apply mask to weight
old_weight = layer.module.weight.data old_weight = layer.module.weight.data
mask = self.calc_mask(layer, config) mask_weight = mask['weight']
layer.module.weight.data = old_weight.mul(mask) layer.module.weight.data = old_weight.mul(mask_weight)
# apply mask to bias
if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
old_bias = layer.module.bias.data
mask_bias = mask['bias']
layer.module.bias.data = old_bias.mul(mask_bias)
# calculate forward # calculate forward
ret = layer._forward(*inputs) ret = layer._forward(*inputs)
return ret return ret
...@@ -224,12 +231,14 @@ class Pruner(Compressor): ...@@ -224,12 +231,14 @@ class Pruner(Compressor):
for name, m in self.bound_model.named_modules(): for name, m in self.bound_model.named_modules():
if name == "": if name == "":
continue continue
mask = self.mask_dict.get(name) masks = self.mask_dict.get(name)
if mask is not None: if masks is not None:
mask_sum = mask.sum().item() mask_sum = masks['weight'].sum().item()
mask_num = mask.numel() mask_num = masks['weight'].numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num) _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
m.weight.data = m.weight.data.mul(mask) m.weight.data = m.weight.data.mul(masks['weight'])
if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
m.bias.data = m.bias.data.mul(masks['bias'])
else: else:
_logger.info('Layer: %s NOT compressed', name) _logger.info('Layer: %s NOT compressed', name)
torch.save(self.bound_model.state_dict(), model_path) torch.save(self.bound_model.state_dict(), model_path)
...@@ -258,7 +267,6 @@ class Quantizer(Compressor): ...@@ -258,7 +267,6 @@ class Quantizer(Compressor):
""" """
quantize should overload this method to quantize weight. quantize should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model. This method is effectively hooked to :meth:`forward` of the model.
Parameters Parameters
---------- ----------
weight : Tensor weight : Tensor
...@@ -272,7 +280,6 @@ class Quantizer(Compressor): ...@@ -272,7 +280,6 @@ class Quantizer(Compressor):
""" """
quantize should overload this method to quantize output. quantize should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model. This method is effectively hooked to :meth:`forward` of the model.
Parameters Parameters
---------- ----------
output : Tensor output : Tensor
...@@ -286,7 +293,6 @@ class Quantizer(Compressor): ...@@ -286,7 +293,6 @@ class Quantizer(Compressor):
""" """
quantize should overload this method to quantize input. quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model. This method is effectively hooked to :meth:`forward` of the model.
Parameters Parameters
---------- ----------
inputs : Tensor inputs : Tensor
...@@ -300,7 +306,6 @@ class Quantizer(Compressor): ...@@ -300,7 +306,6 @@ class Quantizer(Compressor):
def _instrument_layer(self, layer, config): def _instrument_layer(self, layer, config):
""" """
Create a wrapper forward function to replace the original one. Create a wrapper forward function to replace the original one.
Parameters Parameters
---------- ----------
layer : LayerInfo layer : LayerInfo
...@@ -365,7 +370,6 @@ class QuantGrad(torch.autograd.Function): ...@@ -365,7 +370,6 @@ class QuantGrad(torch.autograd.Function):
""" """
This method should be overrided by subclass to provide customized backward function, This method should be overrided by subclass to provide customized backward function,
default implementation is Straight-Through Estimator default implementation is Straight-Through Estimator
Parameters Parameters
---------- ----------
tensor : Tensor tensor : Tensor
...@@ -375,7 +379,6 @@ class QuantGrad(torch.autograd.Function): ...@@ -375,7 +379,6 @@ class QuantGrad(torch.autograd.Function):
quant_type : QuantType quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
you can define different behavior for different types. you can define different behavior for different types.
Returns Returns
------- -------
tensor tensor
...@@ -399,3 +402,4 @@ def _check_weight(module): ...@@ -399,3 +402,4 @@ def _check_weight(module):
return isinstance(module.weight.data, torch.Tensor) return isinstance(module.weight.data, torch.Tensor)
except AttributeError: except AttributeError:
return False return False
\ No newline at end of file
...@@ -17,6 +17,7 @@ class LotteryTicketPruner(Pruner): ...@@ -17,6 +17,7 @@ class LotteryTicketPruner(Pruner):
4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0). 4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
5. Repeat step 2, 3, and 4. 5. Repeat step 2, 3, and 4.
""" """
def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True): def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
""" """
Parameters Parameters
...@@ -55,7 +56,8 @@ class LotteryTicketPruner(Pruner): ...@@ -55,7 +56,8 @@ class LotteryTicketPruner(Pruner):
assert 'prune_iterations' in config, 'prune_iterations must exist in your config' assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
assert 'sparsity' in config, 'sparsity must exist in your config' assert 'sparsity' in config, 'sparsity must exist in your config'
if prune_iterations is not None: if prune_iterations is not None:
assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config' assert prune_iterations == config[
'prune_iterations'], 'The values of prune_iterations must be equal in your config'
prune_iterations = config['prune_iterations'] prune_iterations = config['prune_iterations']
return prune_iterations return prune_iterations
...@@ -67,8 +69,8 @@ class LotteryTicketPruner(Pruner): ...@@ -67,8 +69,8 @@ class LotteryTicketPruner(Pruner):
if print_mask: if print_mask:
print('mask: ', mask) print('mask: ', mask)
# calculate current sparsity # calculate current sparsity
mask_num = mask.sum().item() mask_num = mask['weight'].sum().item()
mask_size = mask.numel() mask_size = mask['weight'].numel()
print('sparsity: ', 1 - mask_num / mask_size) print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default') torch.set_printoptions(profile='default')
...@@ -84,11 +86,11 @@ class LotteryTicketPruner(Pruner): ...@@ -84,11 +86,11 @@ class LotteryTicketPruner(Pruner):
curr_sparsity = self._calc_sparsity(sparsity) curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name) curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask w_abs = weight.abs() * curr_mask['weight']
k = int(w_abs.numel() * curr_sparsity) k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight) mask = torch.gt(w_abs, threshold).type_as(weight)
return mask return {'weight': mask}
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
""" """
......
...@@ -136,12 +136,12 @@ class CompressorTestCase(TestCase): ...@@ -136,12 +136,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0]) masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1) pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1]) masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2 @tf2
def test_tf_fpgm_pruner(self): def test_tf_fpgm_pruner(self):
...@@ -190,8 +190,8 @@ class CompressorTestCase(TestCase): ...@@ -190,8 +190,8 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1]) mask2 = pruner.calc_mask(layer2, config_list[1])
assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.])) assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.])) assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
def test_torch_slim_pruner(self): def test_torch_slim_pruner(self):
""" """
...@@ -218,8 +218,10 @@ class CompressorTestCase(TestCase): ...@@ -218,8 +218,10 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}] config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float() model.bn1.weight.data = torch.tensor(w).float()
...@@ -230,8 +232,10 @@ class CompressorTestCase(TestCase): ...@@ -230,8 +232,10 @@ class CompressorTestCase(TestCase):
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
def test_torch_QAT_quantizer(self): def test_torch_QAT_quantizer(self):
model = TorchModel() model = TorchModel()
......