Unverified commit ec5af41f authored by Ningxin Zheng, committed by GitHub

Constraint-aware one-shot pruners (#2657)

parent a4802083
# Dependency-aware Mode for Filter Pruning
Currently, we have several filter pruning algorithms for convolutional layers: FPGM Pruner, L1Filter Pruner, L2Filter Pruner, Activation APoZ Rank Filter Pruner, Activation Mean Rank Filter Pruner, and Taylor FO On Weight Pruner. These filter pruning algorithms prune each convolutional layer separately: while pruning a convolutional layer, the algorithm quantifies the importance of each filter based on some specific rule (such as the l1 norm) and prunes the less important filters.

As the [dependency analysis utils](./CompressionUtils.md) show, if the output channels of two convolutional layers (conv1, conv2) are added together, then these two conv layers have a channel dependency with each other (for more details, please see [Compression Utils](./CompressionUtils.md)). Take the following figure as an example.
![](../../img/mask_conflict.jpg)
If we prune the first 50% of the output channels (filters) of conv1 and the last 50% of the output channels of conv2, then although both layers have 50% of their filters pruned, the speedup module still needs to add zeros to align the output channels before the addition. In this case, we cannot harvest any speed benefit from the model pruning, as the sketch below illustrates.
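To make the conflict concrete, here is a minimal sketch (a hypothetical 4-channel example in plain PyTorch, not NNI code):

```python
import torch

# channel masks chosen independently by the two pruners (1 = keep, 0 = prune)
conv1_mask = torch.tensor([0., 0., 1., 1.])  # first 50% pruned
conv2_mask = torch.tensor([1., 1., 0., 0.])  # last 50% pruned

# the outputs are added together, so a channel can only be physically
# removed if BOTH layers prune it
removable = (conv1_mask == 0) & (conv2_mask == 0)
print(removable.sum().item())  # 0 -> no channel can be removed, no speedup

# if both layers prune the SAME channels, half of them can be removed
aligned_mask = torch.tensor([0., 0., 1., 1.])
removable = (aligned_mask == 0) & (aligned_mask == 0)
print(removable.sum().item())  # 2 -> the add operation shrinks to 2 channels
```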
To better gain the speed benefit of model pruning, we add a dependency-aware mode to the filter pruners. In the dependency-aware mode, the pruner prunes the model based not only on the metric of each filter (such as its l1 norm), but also on the topology of the whole network architecture.

In the dependency-aware mode (`dependency_aware` set to `True`), the pruner tries to prune the same output channels for the layers that have channel dependencies with each other, as shown in the following figure.
![](../../img/dependency-aware.jpg)
Take the dependency-aware mode of the L1Filter Pruner as an example. Specifically, for each channel, the pruner sums the L1 norms of the corresponding filters across all the layers in the dependency set. Obviously, the number of channels that can actually be pruned from this dependency set is determined by the minimum sparsity among the layers in the set (denoted by `min_sparsity`). According to the L1 norm sums, the pruner first prunes the same `min_sparsity` fraction of channels in all the layers. Next, the pruner additionally prunes `sparsity` - `min_sparsity` channels in each convolutional layer based on that layer's own per-channel L1 norms. For example, suppose the output channels of `conv1` and `conv2` are added together and the configured sparsities of `conv1` and `conv2` are 0.3 and 0.2 respectively. In this case, the `dependency-aware pruner` will

- First, prune the same 20% of channels in `conv1` and `conv2` according to the L1 norm sums of `conv1` and `conv2`.
- Second, additionally prune 10% of the channels in `conv1` according to the per-channel L1 norms of `conv1` (see the sketch after this list).
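As a rough illustration of the two-step selection just described, here is a minimal sketch (standalone PyTorch, not the actual NNI implementation; the tensor sizes and sparsities are made up):

```python
import torch

torch.manual_seed(0)
# fake per-channel L1 norms for two dependent layers with 10 output channels
l1_conv1 = torch.rand(10)
l1_conv2 = torch.rand(10)
sparsity1, sparsity2 = 0.3, 0.2
min_sparsity = min(sparsity1, sparsity2)  # 0.2

# step 1: prune the same min_sparsity channels for both layers,
# ranked by the sum of the L1 norms across the dependency set
channel_sum = l1_conv1 + l1_conv2
num_common = int(10 * min_sparsity)  # 2 channels
common_pruned = torch.topk(channel_sum, num_common, largest=False).indices

# step 2: conv1 still needs sparsity1 - min_sparsity more channels pruned,
# chosen by its own L1 norms among the remaining channels
remaining = [i for i in range(10) if i not in common_pruned.tolist()]
num_extra = int(10 * (sparsity1 - min_sparsity))  # 1 channel
extra_order = sorted(remaining, key=lambda i: l1_conv1[i])
extra_pruned = extra_order[:num_extra]

print('pruned in both layers:', sorted(common_pruned.tolist()))
print('pruned only in conv1 :', extra_pruned)
```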
In addition, for the convolutional layers that have more than one filter group, the `dependency-aware pruner` will also try to prune the same number of channels in each filter group, as the sketch below illustrates.
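A minimal sketch of this per-group constraint (hypothetical numbers; the real implementation additionally takes the least common multiple of the group numbers across the dependency set):

```python
import torch

torch.manual_seed(0)
channel_sum = torch.rand(8)   # importance of 8 channels
groups, sparsity = 2, 0.5     # grouped conv with 2 filter groups
group_step = 8 // groups      # channels per group
pruned_per_group = int(8 * sparsity) // groups

# prune the same number of channels inside every group, so the
# grouped convolution stays balanced after the speedup
masks = []
for gid in range(groups):
    chunk = channel_sum[gid * group_step:(gid + 1) * group_step]
    threshold = torch.topk(chunk, pruned_per_group, largest=False).values.max()
    masks.append(torch.gt(chunk, threshold).float())
mask = torch.cat(masks)
print(mask)  # exactly pruned_per_group zeros in each group
```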
Overall, the pruner prunes the model according to the L1 norm of each filter while trying to meet the topological constraints (channel dependency, group dependency, etc.). In this way, the dependency-aware mode provides a better speed gain from the model pruning.
## Usage
In this section, we show how to enable the dependency-aware mode for a filter pruner. Currently, only the one-shot pruners, i.e., FPGM Pruner, L1Filter Pruner, L2Filter Pruner, Activation APoZ Rank Filter Pruner, Activation Mean Rank Filter Pruner, and Taylor FO On Weight Pruner, support the dependency-aware mode.
To enable the dependency-aware mode for `L1FilterPruner`:
```python
import torch
from nni.compression.torch import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# dummy_input is necessary for the dependency_aware mode
dummy_input = torch.ones(1, 3, 224, 224).cuda()
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# for L2FilterPruner
# pruner = L2FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# for FPGMPruner
# pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
# for ActivationAPoZRankFilterPruner
# pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
# for ActivationMeanRankFilterPruner
# pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
# for TaylorFOWeightFilterPruner
# pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```
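After `pruner.compress()` finishes, the speed benefit itself is realized by the speedup tool. As a sketch continuing the snippet above (assuming the `ModelSpeedup` API of this NNI release; see the Model Speedup tutorial for the authoritative usage, and note that `create_model` is a hypothetical model factory of your own):

```python
from nni.compression.torch import ModelSpeedup

# export the pruned weights and the mask file produced by the pruner
pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth')

# re-create the model, load the pruned weights, then let the speedup tool
# replace the masked layers with physically smaller ones; dependency-aware
# masks let this step actually shrink the channels
model = create_model().cuda()  # hypothetical helper, your model factory
model.load_state_dict(torch.load('pruned_model.pth'))
m_speedup = ModelSpeedup(model, dummy_input, 'mask.pth')
m_speedup.speedup_model()
```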
## Evaluation
To compare the performance of the pruner with and without the dependency-aware mode, we use the L1FilterPruner to prune MobileNet_V2, once with the dependency-aware mode turned on and once with it turned off. To simplify the experiment, we use uniform pruning, meaning we allocate the same sparsity to all convolutional layers in the model.

We trained a MobileNet_V2 model on the CIFAR-10 dataset and pruned the model based on this pretrained checkpoint. The following figure shows the accuracy and FLOPs of the model pruned by the different pruners.
![](../../img/mobilev2_l1_cifar.jpg)
In the figure, `Dependency-aware` represents the L1FilterPruner with the dependency-aware mode enabled, `L1 Filter` is the normal `L1FilterPruner` without the dependency-aware mode, and `No-Dependency` means the pruner only prunes the layers that have no channel dependency with other layers. As we can see in the figure, when the dependency-aware mode is enabled, the pruner achieves higher accuracy under the same FLOPs.
...@@ -114,7 +114,9 @@ FPGMPruner prune filters with the smallest geometric median.

![](../../img/fpgm_fig1.png)

>Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance.

We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to [dependency-aware](./DependencyAware.md) for more details.

### Usage
...@@ -154,6 +156,8 @@ This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https:

> 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel
> weights are copied to the new model.

In addition, we also provide a dependency-aware mode for the L1FilterPruner. For more details about the dependency-aware mode, please refer to [dependency-aware mode](./DependencyAware.md).

### Usage

PyTorch code
...@@ -189,6 +193,8 @@ The experiments code can be found at [examples/model_compress](https://github.c

This is a structured pruning algorithm that prunes the filters with the smallest L2 norm of the weights. It is implemented as a one-shot pruner.

We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to [dependency-aware](./DependencyAware.md) for more details.

### Usage

PyTorch code
...@@ -200,6 +206,7 @@ pruner = L2FilterPruner(model, config_list)
pruner.compress()
```

### User configuration for L2Filter Pruner

##### PyTorch
...@@ -208,6 +215,7 @@ pruner.compress()
```

***

## ActivationAPoZRankFilter Pruner

ActivationAPoZRankFilter Pruner is a pruner which prunes the filters with the smallest importance criterion `APoZ` calculated from the output activations of convolution layers to achieve a preset level of network sparsity. The pruning criterion `APoZ` is explained in the paper [Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures](https://arxiv.org/abs/1607.03250).
...@@ -216,6 +224,8 @@ The APoZ is defined as:

![](../../img/apoz.png)

We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to [dependency-aware](./DependencyAware.md) for more details.

### Usage

PyTorch code
...@@ -234,6 +244,8 @@ Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers withi

You can view [example](https://github.com/microsoft/nni/blob/master/examples/model_compress/model_prune_torch.py) for more information.

### User configuration for ActivationAPoZRankFilter Pruner

##### PyTorch
...@@ -247,6 +259,8 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod

ActivationMeanRankFilterPruner is a pruner which prunes the filters with the smallest importance criterion `mean activation` calculated from the output activations of convolution layers to achieve a preset level of network sparsity. The pruning criterion `mean activation` is explained in section 2.2 of the paper [Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440). Other pruning criteria mentioned in this paper will be supported in a future release.

We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to [dependency-aware](./DependencyAware.md) for more details.

### Usage

PyTorch code
...@@ -265,6 +279,7 @@ Note: ActivationMeanRankFilterPruner is used to prune convolutional layers withi

You can view [example](https://github.com/microsoft/nni/blob/master/examples/model_compress/model_prune_torch.py) for more information.

### User configuration for ActivationMeanRankFilterPruner

##### PyTorch
...@@ -273,6 +288,7 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
```

***

## TaylorFOWeightFilter Pruner

TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based on estimated importance calculated from the first-order Taylor expansion on weights to achieve a preset level of network sparsity. The estimated importance of filters is defined in the paper [Importance Estimation for Neural Network Pruning](http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf). Other pruning criteria mentioned in this paper will be supported in a future release.
...@@ -281,6 +297,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based

![](../../img/importance_estimation_sum.png)

We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to [dependency-aware](./DependencyAware.md) for more details.

### Usage

PyTorch code
......
...@@ -17,7 +17,7 @@ For details, please refer to the following tutorials:

    Overview <Compressor/Overview>
    Quick Start <Compressor/QuickStart>
-    Pruners <Compressor/Pruner>
+    Pruning <pruning>
    Quantizers <Compressor/Quantizer>
    Automatic Model Compression <Compressor/AutoCompression>
    Model Speedup <Compressor/ModelSpeedup>
......
#################
Pruning
#################
NNI provides several pruning algorithms that support fine-grained weight pruning and structural filter pruning.
They support TensorFlow and PyTorch with a unified interface.
To prune a model, users only need to add a few lines of code.
For structural filter pruning, NNI also provides a dependency-aware mode. In the dependency-aware mode, the
filter pruner gets a better speed gain after the speedup step.

For details, please refer to the following tutorials:

.. toctree::
    :maxdepth: 2

    Pruners <Compressor/Pruner>
    Dependency Aware Mode <Compressor/DependencyAware>
...@@ -48,7 +48,7 @@ prune_config = {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': FPGMPruner,
-        'config_list':[{
+        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }]
...@@ -85,6 +85,7 @@ prune_config = {
    }
}


def get_data_loaders(dataset_name='mnist', batch_size=128):
    assert dataset_name in ['cifar10', 'mnist']
...@@ -98,20 +99,23 @@ def get_data_loaders(dataset_name='mnist', batch_size=128):
    train_loader = DataLoader(
        ds_class(
            './data', train=True, download=True,
-            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
+            transform=transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
        ),
        batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        ds_class(
            './data', train=False, download=True,
-            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
+            transform=transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
        ),
        batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader


class NaiveModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
...@@ -132,6 +136,7 @@ class NaiveModel(torch.nn.Module):
        x = self.fc2(x)
        return x


def create_model(model_name='naive'):
    assert model_name in ['naive', 'vgg16', 'vgg19']
...@@ -142,10 +147,18 @@ def create_model(model_name='naive'):
    else:
        return VGG(19)


-def create_pruner(model, pruner_name, optimizer=None):
+def create_pruner(model, pruner_name, optimizer=None, dependency_aware=False, dummy_input=None):
    pruner_class = prune_config[pruner_name]['pruner_class']
    config_list = prune_config[pruner_name]['config_list']
-    return pruner_class(model, config_list, optimizer)
+    kw_args = {}
+    if dependency_aware:
+        print('Enable the dependency_aware mode')
+        # note that, not all pruners support the dependency_aware mode
+        kw_args['dependency_aware'] = True
+        kw_args['dummy_input'] = dummy_input
+    pruner = pruner_class(model, config_list, optimizer, **kw_args)
+    return pruner
def train(model, device, train_loader, optimizer):
    model.train()
...@@ -157,7 +170,9 @@ def train(model, device, train_loader, optimizer):
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
-                print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
+                print('{:2.0f}% Loss {}'.format(
+                    100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
...@@ -167,7 +182,8 @@ def test(model, device, test_loader):
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
-            test_loss += F.cross_entropy(output, target, reduction='sum').item()
+            test_loss += F.cross_entropy(output,
+                                         target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
...@@ -177,20 +193,25 @@ def test(model, device, test_loader):
        test_loss, acc))
    return acc


def main(args):
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    device = torch.device(
+        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    os.makedirs(args.checkpoints_dir, exist_ok=True)

    model_name = prune_config[args.pruner_name]['model_name']
    dataset_name = prune_config[args.pruner_name]['dataset_name']
    train_loader, test_loader = get_data_loaders(dataset_name, args.batch_size)
+    dummy_input, _ = next(iter(train_loader))
+    dummy_input = dummy_input.to(device)
    model = create_model(model_name).cuda()

    if args.resume_from is not None and os.path.exists(args.resume_from):
        print('loading checkpoint {} ...'.format(args.resume_from))
        model.load_state_dict(torch.load(args.resume_from))
        test(model, device, test_loader)
    else:
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        if args.multi_gpu and torch.cuda.device_count():
            model = nn.DataParallel(model)
...@@ -204,17 +225,21 @@ def main(args):
        print('start model pruning...')

-        model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
-        mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
+        model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(
+            model_name, dataset_name, args.pruner_name))
+        mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(
+            model_name, dataset_name, args.pruner_name))

        # pruner needs to be initialized from a model not wrapped by DataParallel
        if isinstance(model, nn.DataParallel):
            model = model.module

-        optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
+        optimizer_finetune = torch.optim.SGD(
+            model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
        best_top1 = 0

-        pruner = create_pruner(model, args.pruner_name, optimizer_finetune)
+        pruner = create_pruner(model, args.pruner_name,
+                               optimizer_finetune, args.dependency_aware, dummy_input)
        model = pruner.compress()

        if args.multi_gpu and torch.cuda.device_count() > 1:
...@@ -231,15 +256,23 @@ def main(args):
        # mask_path stores mask_dict of the pruned model
        pruner.export_model(model_path=model_path, mask_path=mask_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument("--pruner_name", type=str, default="level", help="pruner name")
+    parser.add_argument("--pruner_name", type=str,
+                        default="level", help="pruner name")
    parser.add_argument("--batch_size", type=int, default=256)
-    parser.add_argument("--pretrain_epochs", type=int, default=10, help="training epochs before model pruning")
-    parser.add_argument("--prune_epochs", type=int, default=10, help="training epochs for model pruning")
-    parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="checkpoints directory")
-    parser.add_argument("--resume_from", type=str, default=None, help="pretrained model weights")
-    parser.add_argument("--multi_gpu", action="store_true", help="Use multiple GPUs for training")
+    parser.add_argument("--pretrain_epochs", type=int,
+                        default=10, help="training epochs before model pruning")
+    parser.add_argument("--prune_epochs", type=int, default=10,
+                        help="training epochs for model pruning")
+    parser.add_argument("--checkpoints_dir", type=str,
+                        default="./checkpoints", help="checkpoints directory")
+    parser.add_argument("--resume_from", type=str,
+                        default=None, help="pretrained model weights")
+    parser.add_argument("--multi_gpu", action="store_true",
+                        help="Use multiple GPUs for training")
+    parser.add_argument("--dependency_aware", action="store_true", default=False,
+                        help="whether to enable the dependency_aware mode for the pruner")

    args = parser.parse_args()
    main(args)
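With these changes, the dependency-aware mode of this example can be switched on from the command line, e.g. `python model_prune_torch.py --pruner_name fpgm --dependency_aware` (a hypothetical invocation; the available `--pruner_name` values are the keys of `prune_config`).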
...@@ -13,4 +13,3 @@ from .admm_pruner import ADMMPruner
from .auto_compress_pruner import AutoCompressPruner
from .sensitivity_pruner import SensitivityPruner
from .amc import AMCPruner
...@@ -3,14 +3,19 @@
import logging
from schema import And, Optional

+from nni._graph_utils import TorchModuleGraph
+from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
from .constants import MASKER_DICT
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner

-__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner', \
-    'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
+__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
+           'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']

-logger = logging.getLogger('torch pruner')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
class OneshotPruner(Pruner):
    """
...@@ -35,7 +40,8 @@ class OneshotPruner(Pruner):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
-        self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)
+        self.masker = MASKER_DICT[pruning_algorithm](
+            model, self, **algo_kwargs)

    def validate_config(self, model, config_list):
        """
...@@ -75,7 +81,8 @@ class OneshotPruner(Pruner):
            sparsity = wrapper.config['sparsity']
            if not wrapper.if_calculated:
-                masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
+                masks = self.masker.calc_mask(
+                    sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)

                # masker.calc_mask returning None means the mask was not calculated successfully; can try later
                if masks is not None:
...@@ -84,6 +91,7 @@ class OneshotPruner(Pruner):
            else:
                return None
class LevelPruner(OneshotPruner):
    """
    Parameters
...@@ -97,9 +105,11 @@ class LevelPruner(OneshotPruner):
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer)
class SlimPruner(OneshotPruner):
    """
    Parameters
...@@ -113,6 +123,7 @@ class SlimPruner(OneshotPruner):
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer)
...@@ -128,9 +139,50 @@ class SlimPruner(OneshotPruner):
        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')


class _StructuredFilterPruner(OneshotPruner):
-    def __init__(self, model, config_list, pruning_algorithm, optimizer=None, **algo_kwargs):
-        super().__init__(model, config_list, pruning_algorithm=pruning_algorithm, optimizer=optimizer, **algo_kwargs)
    """
    _StructuredFilterPruner has two ways to calculate the masks
    for conv layers. In the normal way, the _StructuredFilterPruner
    will calculate the mask of each layer separately. For example, each
    conv layer determines which filters should be pruned according to its L1
    norm. In contrast, in the dependency-aware way, the layers in a
    dependency group will be pruned jointly, and these layers will be forced
    to prune the same channels.
    """

+    def __init__(self, model, config_list, pruning_algorithm, optimizer=None, dependency_aware=False, dummy_input=None, **algo_kwargs):
+        super().__init__(model, config_list, pruning_algorithm=pruning_algorithm,
+                         optimizer=optimizer, **algo_kwargs)
        self.dependency_aware = dependency_aware
        # set the dependency-aware switch for the masker
        self.masker.dependency_aware = dependency_aware
        self.dummy_input = dummy_input
        if self.dependency_aware:
            errmsg = "When dependency_aware is set, the dummy_input should not be None"
            assert self.dummy_input is not None, errmsg
            # get the TorchModuleGraph of the target model;
            # to trace the model, we need to unwrap the wrappers
            self._unwrap_model()
            self.graph = TorchModuleGraph(model, dummy_input)
            self._wrap_model()
            self.channel_depen = ChannelDependency(
                traced_model=self.graph.trace)
            self.group_depen = GroupDependency(traced_model=self.graph.trace)
            self.channel_depen = self.channel_depen.dependency_sets
            self.channel_depen = {
                name: sets for sets in self.channel_depen for name in sets}
            self.group_depen = self.group_depen.dependency_sets

    def update_mask(self):
        if not self.dependency_aware:
            # if we use the normal way to update the mask,
            # then call the update_mask of the parent class
            super(_StructuredFilterPruner, self).update_mask()
        else:
            # if we update the mask in a dependency-aware way,
            # then we call _dependency_update_mask
            self._dependency_update_mask()

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
...@@ -141,6 +193,71 @@ class _StructuredFilterPruner(OneshotPruner):
        schema.validate(config_list)

    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
        """
        Calculate the masks for the conv layers in the same
        channel dependency set. All the layers passed in have
        the same number of channels.

        Parameters
        ----------
        wrappers: list
            The list of the wrappers in the same channel dependency
            set.
        wrappers_idx: list
            The list of the indexes of the wrappers.

        Returns
        -------
        masks: dict
            A dict object that contains the masks of the layers in this
            dependency group; the key is the name of the convolutional layer.
        """
        # the number of filter groups for each conv layer;
        # note that this number may differ from the layer's
        # original number of filter groups
        groups = [self.group_depen[_w.name] for _w in wrappers]
        sparsities = [_w.config['sparsity'] for _w in wrappers]
        masks = self.masker.calc_mask(
            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
        if masks is not None:
            # if masks is None, then the mask calculation failed;
            # for example, in activation-related maskers, enough batches
            # of data should be passed through the model before the
            # masks can be calculated successfully
            for _w in wrappers:
                _w.if_calculated = True
        return masks

    def _dependency_update_mask(self):
        """
        In the original update_mask, the wrapper of each layer updates its
        own mask according to the sparsity specified in the config_list. In
        contrast, in _dependency_update_mask, we may prune several layers at the
        same time according to the sparsities and the channel/group dependencies.
        """
        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
        for wrapper in self.get_modules_wrapper():
            if wrapper.if_calculated:
                continue
            # find all the conv layers that have a channel dependency with this layer
            # and prune all these layers at the same time
            _names = [x for x in self.channel_depen[wrapper.name]]
            logger.info('Pruning the dependent layers: %s', ','.join(_names))
            _wrappers = [name2wrapper[name]
                         for name in _names if name in name2wrapper]
            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]

            masks = self._dependency_calc_mask(
                _wrappers, _names, wrappers_idx=_wrapper_idxes)
            if masks is not None:
                for layer in masks:
                    for mask_type in masks[layer]:
                        assert hasattr(
                            name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" % (mask_type, layer)
                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])


class L1FilterPruner(_StructuredFilterPruner):
    """
    Parameters
...@@ -153,9 +270,23 @@ class L1FilterPruner(_StructuredFilterPruner):
        - op_types : Only Conv2d is supported in L1FilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the L1 norm of the weights and the channel/group
        dependencies of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that if this flag is set to True,
        the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None):
-        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer)
+    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
class L2FilterPruner(_StructuredFilterPruner):
    """
...@@ -169,9 +300,23 @@ class L2FilterPruner(_StructuredFilterPruner):
        - op_types : Only Conv2d is supported in L2FilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the L2 norm of the weights and the channel/group
        dependencies of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that if this flag is set to True,
        the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None):
-        super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer)
+    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
class FPGMPruner(_StructuredFilterPruner):
    """
...@@ -185,9 +330,23 @@ class FPGMPruner(_StructuredFilterPruner):
        - op_types : Only Conv2d is supported in FPGM Pruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the FPGM criterion (the distance of each filter to the
        geometric median) and the channel/group dependencies of the model. In this way, the
        pruner will force the conv layers that have dependencies to prune the same channels,
        so the speedup module can better harvest the speed benefit from the pruned model.
        Note that if this flag is set to True, the dummy_input cannot be None, because the
        pruner needs a dummy input to trace the dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None):
-        super().__init__(model, config_list, pruning_algorithm='fpgm', optimizer=optimizer)
+    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='fpgm',
+                         dependency_aware=dependency_aware, dummy_input=dummy_input, optimizer=optimizer)
class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
    """
...@@ -201,9 +360,28 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
        - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    statistics_batch_num: int
        The number of batches used to collect the statistics.
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the first-order Taylor importance of the weights and the
        channel/group dependencies of the model. In this way, the pruner will force the conv
        layers that have dependencies to prune the same channels, so the speedup module can
        better harvest the speed benefit from the pruned model. Note that if this flag is set
        to True, the dummy_input cannot be None, because the pruner needs a dummy input to
        trace the dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
-        super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1,
+                 dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='taylorfo',
+                         dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         optimizer=optimizer, statistics_batch_num=statistics_batch_num)
class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
    """
...@@ -217,10 +395,30 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
        - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model
    activation: str
        The activation type.
    statistics_batch_num: int
        The number of batches used to collect the activation statistics.
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the APoZ of the output activations and the channel/group
        dependencies of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that if this flag is set to True,
        the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
-        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \
-                         activation=activation, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer=None, activation='relu',
+                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, statistics_batch_num=statistics_batch_num)
class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
    """
...@@ -233,8 +431,26 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
        - sparsity : How much percentage of convolutional filters are to be pruned.
        - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
    optimizer: torch.optim.Optimizer
        Optimizer used to train model.
    activation: str
        The activation type.
    statistics_batch_num: int
        The number of batches used to collect the activation statistics.
    dependency_aware: bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
        prune the model according to the mean activation of the outputs and the channel/group
        dependencies of the model. In this way, the pruner will force the conv layers
        that have dependencies to prune the same channels, so the speedup module can better
        harvest the speed benefit from the pruned model. Note that if this flag is set to True,
        the dummy_input cannot be None, because the pruner needs a dummy input to trace the
        dependencies between the conv layers.
    dummy_input : torch.Tensor
        The dummy input to analyze the topology constraints. Note that the dummy_input
        should be on the same device as the model.
    """

-    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
-        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \
-                         activation=activation, statistics_batch_num=statistics_batch_num)
+    def __init__(self, model, config_list, optimizer=None, activation='relu',
+                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, statistics_batch_num=statistics_batch_num)
...@@ -7,12 +7,13 @@ import numpy as np
import torch
from .weight_masker import WeightMasker

-__all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker', \
-    'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker', \
-    'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker', 'AMCWeightMasker']
+__all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker',
+           'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker',
+           'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker', 'AMCWeightMasker']

logger = logging.getLogger('torch filter pruners')


class StructuredWeightMasker(WeightMasker):
    """
    A structured pruning masker base class that prunes convolutional layer filters.
...@@ -31,14 +32,48 @@ class StructuredWeightMasker(WeightMasker):
        be round up to 28 (which can be divided by 4) and only 4 filters are pruned.
    """

-    def __init__(self, model, pruner, preserve_round=1):
+    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
        self.model = model
        self.pruner = pruner
        self.preserve_round = preserve_round
+        self.dependency_aware = dependency_aware

-    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
+    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
        """
-        Calculate the mask of given layer.
+        Calculate the mask for `wrapper`.

        Parameters
        ----------
        sparsity: float/list of float
            The target sparsity of the wrapper. If we calculate the mask in
            the normal way, then sparsity is a float number. In contrast, if
            we calculate the mask in the dependency-aware way, sparsity is a
            list of float numbers, each of which corresponds to the sparsity
            of a layer.
        wrapper: PrunerModuleWrapper/list of PrunerModuleWrapper
            The wrapper of the target layer. If we calculate the mask in the normal
            way, then `wrapper` is an instance of PrunerModuleWrapper; else `wrapper`
            is a list of PrunerModuleWrapper.
        wrapper_idx: int/list of int
            The index of the wrapper.
        depen_kwargs: dict
            The kwargs for the dependency-aware mode.
        """
        if not self.dependency_aware:
            # calculate the mask in the normal way: each layer calculates its
            # own mask separately
            return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
        else:
            # if the dependency_aware switch is on, then calculate the mask
            # in the dependency-aware way
            return self._dependency_calc_mask(sparsity, wrapper, wrapper_idx, **depen_kwargs)

    def _get_current_state(self, sparsity, wrapper, wrapper_idx=None):
        """
        Some pruners may prune the layers in an iterative way. In each pruning
        iteration, we may get the current state of this wrapper/layer and continue
        to prune this layer based on the current state. This function gets the
        current pruning state of the target wrapper/layer.

        Parameters
        ----------
        sparsity: float
...@@ -49,10 +84,14 @@ class StructuredWeightMasker(WeightMasker):
            index of this wrapper in pruner's all wrappers
        Returns
        -------
-        dict
-            dictionary for storing masks, keys of the dict:
-            'weight_mask': weight mask tensor
-            'bias_mask': bias mask tensor (optional)
+        base_mask: dict
+            dict object that stores the mask of this wrapper in this iteration. If it is the
+            first iteration, then we create a new mask with all ones. If there is already a
+            mask in this wrapper, then we return the existing mask.
+        weight: tensor
+            the current weight of this layer
+        num_prune: int
+            how many filters we should prune
        """
        msg = 'module type {} is not supported!'.format(wrapper.type)
        assert wrapper.type == 'Conv2d', msg
...@@ -78,17 +117,178 @@ class StructuredWeightMasker(WeightMasker):
        num_prune = int(num_total * sparsity)
        if self.preserve_round > 1:
            num_preserve = num_total - num_prune
-            num_preserve = int(math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
+            num_preserve = int(
+                math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
            if num_preserve > num_total:
-                num_preserve = int(math.floor(num_total * 1. / self.preserve_round) * self.preserve_round)
+                num_preserve = int(math.floor(
+                    num_total * 1. / self.preserve_round) * self.preserve_round)
            num_prune = num_total - num_preserve
+        # weight*mask_weight: apply base mask for iterative pruning
+        return mask, weight * mask_weight, num_prune

    def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        """
        Calculate the mask of given layer.

        Parameters
        ----------
        sparsity: float
            pruning ratio, preserved weight ratio is `1 - sparsity`
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        dict
            dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        mask, weight, num_prune = self._get_current_state(
            sparsity, wrapper, wrapper_idx)
        num_total = weight.size(0)
        if num_total < 2 or num_prune < 1:
            return mask
-        # weight*mask_weight: apply base mask for iterative pruning
-        return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx)
+        return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)

-    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx):
    def _common_channel_to_prune(self, sparsities, wrappers, wrappers_idx, channel_dsets, groups):
        """
        Calculate the common channels that should be pruned by all the layers in this group.
        This function is for filter pruning of Conv layers. If you want to support the
        dependency-aware mode for other ops, you need to inherit this class and overwrite
        `_common_channel_to_prune`.

        Parameters
        ----------
        sparsities : list
            List of floats that specify the sparsity for each conv layer.
        wrappers : list
            List of wrappers
        groups : list
            The number of the filter groups of each layer.
        wrappers_idx : list
            The indexes of the wrappers
        """
        # sparsity configs for each wrapper
        # sparsities = [_w.config['sparsity'] for _w in wrappers]
        # check the type of the input wrappers
        for _w in wrappers:
            msg = 'module type {} is not supported!'.format(_w.type)
            assert _w.type == 'Conv2d', msg
        # Among the dependent layers, the layer with the smallest
        # sparsity determines the final benefit of the speedup
        # module. To better harvest the speed benefit, we need
        # to ensure that these dependent layers prune at least
        # the same `min_sparsity` of channels.
        if len(channel_dsets) == len(wrappers):
            # all the layers in the dependency set are pruned
            min_sparsity = min(sparsities)
        else:
            # not all the layers in the dependency set
            # are pruned
            min_sparsity = 0
        # do not prune the channels that we cannot harvest the speed from
        sparsities = [min_sparsity] * len(sparsities)
        # Find the max number of the filter groups of the dependent
        # layers. The group constraint of this dependency set is decided
        # by the layer with the max number of groups; we should use the
        # least common multiple of all the groups.
        # max_group is not larger than channel_count, because
        # the number of filters is always divisible by the number of groups.
        max_group = np.lcm.reduce(groups)
        channel_count = wrappers[0].module.weight.data.size(0)
        device = wrappers[0].module.weight.device
        channel_sum = torch.zeros(channel_count).to(device)
        for _w, _w_idx in zip(wrappers, wrappers_idx):
            # calculate the L1/L2 sum for all channels
            c_sum = self.get_channel_sum(_w, _w_idx)
            if c_sum is None:
                # return None if the channel sum cannot be
                # calculated yet
                return None
            channel_sum += c_sum
        # prune the same `min_sparsity` of channels based on channel_sum
        # for all the layers in this dependency set
        target_pruned = int(channel_count * min_sparsity)
        # pruned_per_group may be zero, for example for a depthwise conv
        pruned_per_group = int(target_pruned / max_group)
        group_step = int(channel_count / max_group)

        channel_masks = []
        for gid in range(max_group):
            _start = gid * group_step
            _end = (gid + 1) * group_step
            if pruned_per_group > 0:
                threshold = torch.topk(
                    channel_sum[_start: _end], pruned_per_group, largest=False)[0].max()
                group_mask = torch.gt(channel_sum[_start:_end], threshold)
            else:
                group_mask = torch.ones(group_step).to(device)
            channel_masks.append(group_mask)
        channel_masks = torch.cat(channel_masks, dim=0)
        pruned_channel_index = (
            channel_masks == False).nonzero().squeeze(1).tolist()
        logger.info('Prune the %s channels for all dependent layers',
                    ','.join([str(x) for x in pruned_channel_index]))
        return channel_masks
    def _dependency_calc_mask(self, sparsities, wrappers, wrappers_idx, channel_dsets, groups):
        """
        Calculate the masks for the layers in the same dependency set.
        Similar to the original calc_mask, _dependency_calc_mask prunes
        the target layers based on the L1/L2 norm of the weights. However,
        while StructuredWeightMasker prunes each filter purely based on the
        L1/L2 norm of the filter itself, _dependency_calc_mask also tries
        to satisfy the channel/group dependency (see nni.compression.torch.
        utils.shape_dependency for details). Specifically, _dependency_calc_mask
        tries to prune the same channels for the layers that have a channel
        dependency. In addition, this mask calculator also ensures that the
        number of filters pruned in each group is the same (to meet the group
        dependency).

        Parameters
        ----------
        sparsities : list
            List of floats that specify the sparsity of each conv layer.
        wrappers : list
            List of wrappers.
        wrappers_idx : list
            The indexes of the wrappers.
        channel_dsets : list
            The channel dependency set that these layers belong to.
        groups : list
            The number of filter groups of each layer.
        """
        channel_masks = self._common_channel_to_prune(
            sparsities, wrappers, wrappers_idx, channel_dsets, groups)
        # Calculate the mask for each layer based on channel_masks. First,
        # every layer prunes the same channels masked in channel_masks. If
        # the sparsity of a layer is larger than min_sparsity, it will prune
        # an additional (sparsity - min_sparsity) of the channels to meet its
        # own sparsity config.
        masks = {}
        for _pos, _w in enumerate(wrappers):
            _w_idx = wrappers_idx[_pos]
            sparsity = sparsities[_pos]
            name = _w.name
            base_mask, current_weight, num_prune = self._get_current_state(
                sparsity, _w, _w_idx)
            num_total = current_weight.size(0)
            if num_total < 2 or num_prune < 1:
                # too few filters or nothing to prune for this layer;
                # keep the base mask and continue with the next layer
                masks[name] = base_mask
                continue
            _tmp_mask = self.get_mask(
                base_mask, current_weight, num_prune, _w, _w_idx, channel_masks)
            if _tmp_mask is None:
                # if the mask calculation fails
                return None
            masks[name] = _tmp_mask
        return masks
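# --- Worked example (hypothetical names and numbers) ---
# Suppose convA and convB form one dependency set, both with 64 output
# channels, and configured sparsities of 0.5 and 0.25 respectively.
#   min_sparsity = min(0.5, 0.25) = 0.25
#   Step 1: both layers prune the same 0.25 * 64 = 16 channels, selected by
#           the summed channel importance over convA and convB.
#   Step 2: convA prunes int(0.5 * 64) - 16 = 16 additional channels ranked
#           by its own channel importance; convB prunes nothing further.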
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        """
        Calculate the mask of given layer.

        Parameters
@@ -103,12 +303,38 @@ class StructuredWeightMasker(WeightMasker):
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers
        channel_masks: Tensor
            The channels to mask for this layer in advance. In the dependency-aware
            mode, before calculating the masks for each layer, we calculate a common
            mask for all the layers in the dependency set. Pruners that do not
            support the dependency-aware mode can simply ignore this parameter.

        Returns
        -------
        dict
            dictionary for storing masks
        """
        raise NotImplementedError(
            '{} get_mask is not implemented'.format(self.__class__.__name__))
    def get_channel_sum(self, wrapper, wrapper_idx):
        """
        Calculate the importance weight for each channel. To support the
        dependency-aware mode, a one-shot pruner must implement this
        function.

        Parameters
        ----------
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        tensor
            Tensor that indicates the importance of each channel
        """
        raise NotImplementedError(
            '{} get_channel_sum is not implemented'.format(self.__class__.__name__))
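# --- Illustrative sketch of the masker contract (hypothetical subclass) ---
# To plug a new importance criterion into the dependency-aware mode, a
# masker only needs get_channel_sum (per-channel importance) and a get_mask
# that respects the precomputed channel_masks. The class below is a sketch
# of that contract, assuming the surrounding StructuredWeightMasker base
# class; LInfFilterMasker is a made-up name, not part of the library.
import torch

class LInfFilterMasker(StructuredWeightMasker):
    def get_channel_sum(self, wrapper, wrapper_idx):
        # importance of a filter = the largest absolute weight it contains
        weight = wrapper.module.weight.data
        filters = weight.shape[0]
        return weight.abs().view(filters, -1).max(dim=1)[0]

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        importance = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            # zero out the importance of the commonly pruned channels first,
            # so they are guaranteed to fall below the pruning threshold
            importance = importance * channel_masks
        threshold = torch.topk(importance.view(-1), num_prune, largest=False)[0].max()
        mask_weight = torch.gt(importance, threshold)[
            :, None, None, None].expand_as(weight).type_as(weight)
        mask_bias = torch.gt(importance, threshold).type_as(
            weight).detach() if base_mask['bias_mask'] is not None else None
        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}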
class L1FilterPrunerMasker(StructuredWeightMasker):
    """
@@ -119,30 +345,56 @@ class L1FilterPrunerMasker(StructuredWeightMasker):
    https://arxiv.org/abs/1608.08710
    """
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        # get the l1-norm sum for each filter
        w_abs_structured = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            # if we need to mask some channels in advance
            w_abs_structured = w_abs_structured * channel_masks
        threshold = torch.topk(w_abs_structured.view(-1),
                               num_prune, largest=False)[0].max()
        mask_weight = torch.gt(w_abs_structured, threshold)[
            :, None, None, None].expand_as(weight).type_as(weight)
        mask_bias = torch.gt(w_abs_structured, threshold).type_as(
            weight).detach() if base_mask['bias_mask'] is not None else None
        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        filters = weight.shape[0]
        w_abs = weight.abs()
        w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
        return w_abs_structured
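# --- Illustrative sketch: L1 channel importance on a toy weight ---
# A hypothetical 3-filter conv weight of shape (3, 2, 1, 1); the L1
# importance of a filter is the sum of the absolute values of all its
# weights, exactly as computed by get_channel_sum above.
import torch

toy_weight = torch.tensor([[[[0.5]], [[-0.5]]],
                           [[[2.0]], [[0.0]]],
                           [[[-0.1]], [[0.1]]]])   # shape (3, 2, 1, 1)
l1 = toy_weight.abs().view(3, -1).sum(dim=1)
print(l1)  # tensor([1.0000, 2.0000, 0.2000]) -> filter 2 is the weakest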
class L2FilterPrunerMasker(StructuredWeightMasker):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest L2 norm of the weights.
    """
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        # get the l2-norm for each filter
        w_l2_norm = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            # if we need to mask some channels in advance
            w_l2_norm = w_l2_norm * channel_masks
        threshold = torch.topk(
            w_l2_norm.view(-1), num_prune, largest=False)[0].max()
        mask_weight = torch.gt(w_l2_norm, threshold)[
            :, None, None, None].expand_as(weight).type_as(weight)
        mask_bias = torch.gt(w_l2_norm, threshold).type_as(
            weight).detach() if base_mask['bias_mask'] is not None else None
        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        filters = weight.shape[0]
        w = weight.view(filters, -1)
        w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
        return w_l2_norm
class FPGMPrunerMasker(StructuredWeightMasker):
@@ -151,22 +403,23 @@ class FPGMPrunerMasker(StructuredWeightMasker):
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        min_gm_idx = self._get_min_gm_kernel_idx(
            num_prune, wrapper, wrapper_idx, channel_masks)
        for idx in min_gm_idx:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        return base_mask

    def _get_min_gm_kernel_idx(self, num_prune, wrapper, wrapper_idx, channel_masks):
        channel_dist = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            channel_dist = channel_dist * channel_masks
        dist_list = [(channel_dist[i], i)
                     for i in range(channel_dist.size(0))]
        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:num_prune]
        return [x[1] for x in min_gm_kernels]

    def _get_distance_sum(self, weight, out_idx):
@@ -195,6 +448,16 @@ class FPGMPrunerMasker(StructuredWeightMasker):
        x = torch.sqrt(x)
        return x.sum()

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        assert len(weight.size()) in [3, 4]
        dist_list = []
        for out_i in range(weight.size(0)):
            dist_sum = self._get_distance_sum(weight, out_i)
            dist_list.append(dist_sum)
        return torch.Tensor(dist_list).to(weight.device)
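# --- Illustrative sketch: the FPGM distance criterion, vectorized ---
# _get_distance_sum computes, for one filter, the summed Euclidean distance
# to every other filter; filters closest to the geometric median of all
# filters are pruned first. The snippet below is a hypothetical, vectorized
# equivalent of the loop above, using torch.cdist on a toy weight.
import torch

toy_weight = torch.randn(6, 4, 3, 3)            # 6 filters
flat = toy_weight.view(6, -1)                   # one row per filter
dist_sum = torch.cdist(flat, flat, p=2).sum(dim=1)
print(dist_sum)  # per-filter distance sums; the smallest are pruned first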
class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
    """
    A structured pruning algorithm that prunes the filters with the smallest
@@ -203,6 +466,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
    "Importance Estimation for Neural Network Pruning", CVPR 2019.
    http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf
    """

    def __init__(self, model, pruner, statistics_batch_num=1):
        super().__init__(model, pruner)
        self.pruner.statistics_batch_num = statistics_batch_num
@@ -210,14 +474,14 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
        self.pruner.iterations = 0
        self.pruner.patch_optimizer(self.calc_contributions)
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_contribution is None:
            # the number of iterations is not enough
            return None
        if channel_masks is not None:
            channel_contribution = channel_contribution * channel_masks
        prune_indices = torch.argsort(channel_contribution)[:num_prune]
        for idx in prune_indices:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
@@ -233,7 +497,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
            return
        for wrapper in self.pruner.get_modules_wrapper():
            filters = wrapper.module.weight.size(0)
            contribution = (
                wrapper.module.weight*wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
            if wrapper.contribution is None:
                wrapper.contribution = contribution
            else:
@@ -241,6 +506,13 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
        self.pruner.iterations += 1

    def get_channel_sum(self, wrapper, wrapper_idx):
        if self.pruner.iterations < self.pruner.statistics_batch_num:
            return None
        if wrapper.contribution is None:
            return None
        return wrapper.contribution
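# --- Illustrative sketch: the Taylor-FO contribution of each filter ---
# After a backward pass, the contribution of a filter is the squared sum of
# (weight * gradient) over that filter, as accumulated by calc_contributions
# above. Everything below (the tiny conv and the toy loss) is hypothetical.
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, kernel_size=3)
out = conv(torch.randn(2, 3, 8, 8))
out.sum().backward()                      # toy loss, just to populate gradients
filters = conv.weight.size(0)
contribution = (conv.weight * conv.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
print(contribution)  # one importance score per filter; the smallest are pruned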
class ActivationFilterPrunerMasker(StructuredWeightMasker):
    def __init__(self, model, pruner, statistics_batch_num=1, activation='relu'):
@@ -259,7 +531,8 @@ class ActivationFilterPrunerMasker(StructuredWeightMasker):
    def _add_activation_collector(self, pruner):
        def collector(collected_activation):
            def hook(module_, input_, output):
                collected_activation.append(
                    pruner.activation(output.detach().cpu()))
            return hook
        pruner.collected_activation = {}
        pruner._fwd_hook_id += 1
@@ -267,11 +540,13 @@ class ActivationFilterPrunerMasker(StructuredWeightMasker):
        for wrapper_idx, wrapper in enumerate(pruner.get_modules_wrapper()):
            pruner.collected_activation[wrapper_idx] = []
            handle = wrapper.register_forward_hook(
                collector(pruner.collected_activation[wrapper_idx]))
            pruner._fwd_hook_handles[pruner._fwd_hook_id].append(handle)
        return pruner._fwd_hook_id
class ActivationAPoZRankFilterPrunerMasker(ActivationFilterPrunerMasker):
    """
    A structured pruning algorithm that prunes the filters with the
@@ -280,19 +555,22 @@ class ActivationAPoZRankFilterPrunerMasker(ActivationFilterPrunerMasker):
    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
    https://arxiv.org/abs/1607.03250
    """

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        apoz = self.get_channel_sum(wrapper, wrapper_idx)
        if apoz is None:
            # the collected activations are not enough
            return None
        if channel_masks is not None:
            apoz = apoz * channel_masks
        prune_indices = torch.argsort(apoz)[:num_prune]
        for idx in prune_indices:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        if self.pruner.hook_id in self.pruner._fwd_hook_handles:
            self.pruner.remove_activation_collector(self.pruner.hook_id)
        return base_mask
@@ -313,8 +591,18 @@ class ActivationAPoZRankFilterPrunerMasker(ActivationFilterPrunerMasker):
        """
        activations = torch.cat(activations, 0)
        _eq_zero = torch.eq(activations, torch.zeros_like(activations))
        _apoz = torch.sum(_eq_zero, dim=(0, 2, 3), dtype=torch.float64) / \
            torch.numel(_eq_zero[:, 0, :, :])
        # return 1 - APoZ, so that a smaller value means more zero activations
        return torch.ones_like(_apoz) - _apoz

    def get_channel_sum(self, wrapper, wrapper_idx):
        assert wrapper_idx is not None
        activations = self.pruner.collected_activation[wrapper_idx]
        if len(activations) < self.statistics_batch_num:
            # the collected activations are not enough
            return None
        return self._calc_apoz(activations).to(wrapper.module.weight.device)
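# --- Illustrative sketch: APoZ (Average Percentage of Zeros) per channel ---
# A hypothetical post-ReLU activation batch of shape (N, C, H, W); the score
# returned above is 1 - APoZ, so channels that are zero most of the time get
# the smallest scores and are pruned first.
import torch

acts = torch.tensor([[[[0., 1.], [0., 2.]],      # channel 0: half zeros
                      [[0., 0.], [0., 3.]]]])    # channel 1: mostly zeros
eq_zero = torch.eq(acts, torch.zeros_like(acts))
apoz = torch.sum(eq_zero, dim=(0, 2, 3), dtype=torch.float64) / torch.numel(eq_zero[:, 0, :, :])
print(1 - apoz)  # -> [0.5, 0.25]; channel 1 is pruned first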
class ActivationMeanRankFilterPrunerMasker(ActivationFilterPrunerMasker):
    """
@@ -324,19 +612,24 @@ class ActivationMeanRankFilterPrunerMasker(ActivationFilterPrunerMasker):
    "Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
    https://arxiv.org/abs/1611.06440
    """
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        mean_activation = self.get_channel_sum(wrapper, wrapper_idx)
        if mean_activation is None:
            # the collected activations are not enough
            return None
        if channel_masks is not None:
            mean_activation = mean_activation * channel_masks
        prune_indices = torch.argsort(mean_activation)[:num_prune]
        for idx in prune_indices:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        # if the collected activations are not enough, the code cannot reach
        # here, so it is safe to remove the activation collector now
        if self.pruner.hook_id in self.pruner._fwd_hook_handles:
            self.pruner.remove_activation_collector(self.pruner.hook_id)
        return base_mask
@@ -359,6 +652,17 @@ class ActivationMeanRankFilterPrunerMasker(ActivationFilterPrunerMasker):
        mean_activation = torch.mean(activations, dim=(0, 2, 3))
        return mean_activation

    def get_channel_sum(self, wrapper, wrapper_idx):
        assert wrapper_idx is not None
        activations = self.pruner.collected_activation[wrapper_idx]
        if len(activations) < self.statistics_batch_num:
            return None
        # the memory overhead here is acceptable, because only the
        # mean_activation tensor returned by _cal_mean_activation
        # is transferred to the GPU.
        return self._cal_mean_activation(activations).to(wrapper.module.weight.device)
class SlimPrunerMasker(WeightMasker):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
@@ -374,7 +678,8 @@ class SlimPrunerMasker(WeightMasker):
            weight_list.append(layer.module.weight.data.abs().clone())
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
        self.global_threshold = torch.topk(
            all_bn_weights.view(-1), k, largest=False)[0].max()

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
@@ -384,22 +689,27 @@ class SlimPrunerMasker(WeightMasker):
        weight = weight * wrapper.weight_mask
        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {'weight_mask': base_mask.detach(),
                'bias_mask': base_mask.clone().detach()}
        filters = weight.size(0)
        num_prune = int(filters * sparsity)
        if filters >= 2 and num_prune >= 1:
            w_abs = weight.abs()
            mask_weight = torch.gt(
                w_abs, self.global_threshold).type_as(weight)
            mask_bias = mask_weight.clone()
            mask = {'weight_mask': mask_weight.detach(),
                    'bias_mask': mask_bias.detach()}
        return mask
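# --- Illustrative sketch: the global BN-scale threshold used by SlimPrunerMasker ---
# The threshold is computed once over the absolute BN scaling factors of all
# pruned layers, so the sparsity budget is allocated globally rather than
# per layer. The two BN "weights" below are hypothetical toy values.
import torch

bn1_gamma = torch.tensor([0.9, 0.05, 0.7])
bn2_gamma = torch.tensor([0.01, 0.8, 0.02, 0.6])
all_bn_weights = torch.cat([bn1_gamma.abs(), bn2_gamma.abs()])
sparsity = 0.5
k = int(all_bn_weights.shape[0] * sparsity)                       # -> 3
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
print(global_threshold)                    # tensor(0.0500)
print(bn1_gamma.abs() > global_threshold)  # tensor([ True, False,  True])
print(bn2_gamma.abs() > global_threshold)  # tensor([False,  True, False,  True])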
def least_square_sklearn(X, Y):
    from sklearn.linear_model import LinearRegression
    reg = LinearRegression(fit_intercept=False)
    reg.fit(X, Y)
    return reg.coef_
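# --- Illustrative usage of least_square_sklearn (toy data) ---
# AMC uses this helper to reconstruct the weights of a pruned 1x1 conv /
# linear layer from the surviving input channels. Here we just recover a
# known 2x3 coefficient matrix from random inputs; all values are made up.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))             # 100 samples, 3 input features
true_W = np.array([[1., 2., 3.], [4., 5., 6.]])
Y = X @ true_W.T                          # shape (100, 2)
print(least_square_sklearn(X, Y))         # ~= true_W, shape (2, 3)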
class AMCWeightMasker(WeightMasker):
    """
    Weight masker class for the AMC pruner. Currently, AMCPruner only supports pruning kernel
@@ -420,6 +730,7 @@ class AMCWeightMasker(WeightMasker):
    32 - 6 = 26 filters are preserved. If preserve_round is 4, the number of preserved filters
    will be rounded up to 28 (which can be divided by 4) and only 4 filters are pruned.
    """
    def __init__(self, model, pruner, preserve_round=1):
        self.model = model
        self.pruner = pruner
@@ -467,7 +778,8 @@ class AMCWeightMasker(WeightMasker):
        num_prune = int(num_total * sparsity)
        if self.preserve_round > 1:
            num_preserve = num_total - num_prune
            num_preserve = int(
                math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
            if num_preserve > num_total:
                num_preserve = num_total
            num_prune = num_total - num_preserve
@@ -484,7 +796,8 @@ class AMCWeightMasker(WeightMasker):
        if preserve_idx is None:
            importance = np.abs(w).sum((0, 2, 3))
            # sum the magnitude along C_in and sort in descending order
            sorted_idx = np.argsort(-importance)
            d_prime = num_preserve
            preserve_idx = sorted_idx[:d_prime]  # the indexes to preserve
        else:
@@ -499,10 +812,13 @@ class AMCWeightMasker(WeightMasker):
            masked_X = X[:, mask]
            if w.shape[2] == 1:  # 1x1 conv or fc
                rec_weight = least_square_sklearn(X=masked_X, Y=Y)
                # (C_out, K_h, K_w, C_in')
                rec_weight = rec_weight.reshape(-1, 1, 1, d_prime)
                # (C_out, C_in', K_h, K_w)
                rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))
            else:
                raise NotImplementedError(
                    'Current code only supports 1x1 conv now!')
            rec_weight_pad = np.zeros_like(w)
            # pylint: disable=all
            rec_weight_pad[:, mask, :, :] = rec_weight
@@ -513,7 +829,8 @@ class AMCWeightMasker(WeightMasker):
        assert len(rec_weight.shape) == 2
        # now assign
        wrapper.module.weight.data = torch.from_numpy(
            rec_weight).to(weight.device)
        mask_weight = torch.zeros_like(weight)
        if wrapper.type == 'Linear':
@@ -290,4 +290,5 @@ class ChannelMaskConflict(MaskFix):
        _logger.info('Pruned Filters after fixing conflict:')
        pruned_filters = set(list(range(ori_channels)))-channel_remain
        _logger.info(str(sorted(pruned_filters)))
        return self.masks
@@ -484,3 +484,6 @@ class GroupDependency(Dependency):
        for name in self.dependency:
            group = self.dependency[name]
            csv_w.writerow([name, group])

    @property
    def dependency_sets(self):
        return self.dependency
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import unittest
from unittest import TestCase, main
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
from nni.compression.torch import L1FilterPruner, L2FilterPruner, FPGMPruner, \
TaylorFOWeightFilterPruner, ActivationAPoZRankFilterPruner, \
ActivationMeanRankFilterPruner
from nni.compression.torch import ModelSpeedup
unittest.TestLoader.sortTestMethodsUsing = None
MODEL_FILE, MASK_FILE = './model.pth', './mask.pth'
def generate_random_sparsity(model):
"""
generate a random sparsity for all conv layers in the
model.
"""
cfg_list = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity})
return cfg_list
def generate_random_sparsity_v2(model):
"""
only generate a random sparsity for some conv layers in
in the model.
"""
cfg_list = []
for name, module in model.named_modules():
        # randomly pick about 50% of the conv layers
if isinstance(module, nn.Conv2d) and random.uniform(0, 1) > 0.5:
sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity})
return cfg_list
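# --- Example (hypothetical layer names and sparsities) of a generated cfg_list ---
# [{'op_types': ['Conv2d'], 'op_names': ['conv1'], 'sparsity': 0.52},
#  {'op_types': ['Conv2d'], 'op_names': ['layer1.0.conv1'], 'sparsity': 0.83}]
# Each entry pins a single conv layer to its own randomly drawn sparsity,
# which is the config format consumed by the pruners under test.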
class DependencyawareTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_dependency_aware_pruning(self):
model_zoo = ['resnet18']
pruners = [L1FilterPruner, L2FilterPruner, FPGMPruner, TaylorFOWeightFilterPruner]
sparsity = 0.7
cfg_list = [{'op_types': ['Conv2d'], 'sparsity':sparsity}]
dummy_input = torch.ones(1, 3, 224, 224)
for model_name in model_zoo:
for pruner in pruners:
print('Testing on ', pruner)
ori_filters = {}
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False)
                # record the number of filters of each conv layer
for name, module in net.named_modules():
if isinstance(module, nn.Conv2d):
ori_filters[name] = module.out_channels
                # for the pruners that are based on activations, we need to
                # feed enough data before calling the compress function.
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001,
momentum=0.9,
weight_decay=4e-5)
criterion = torch.nn.CrossEntropyLoss()
tmp_pruner = pruner(
net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
                # train a single batch so that the pruner can collect the
                # statistics
optimizer.zero_grad()
out = net(dummy_input)
batchsize = dummy_input.size(0)
loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
loss.backward()
optimizer.step()
tmp_pruner.compress()
tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
# if we want to use the same model, we should unwrap the pruner before the speedup
tmp_pruner._unwrap_model()
ms = ModelSpeedup(net, dummy_input, MASK_FILE)
ms.speedup_model()
for name, module in net.named_modules():
if isinstance(module, nn.Conv2d):
expected = int(ori_filters[name] * (1-sparsity))
filter_diff = abs(expected - module.out_channels)
errmsg = '%s Ori: %d, Expected: %d, Real: %d' % (
name, ori_filters[name], expected, module.out_channels)
                        # because we are using the dependency-aware mode, the number of
                        # filters after speedup should be ori_filters[name] * (1 - sparsity)
print(errmsg)
assert filter_diff <= 1, errmsg
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_dependency_aware_random_config(self):
model_zoo = ['resnet18']
pruners = [L1FilterPruner, L2FilterPruner, FPGMPruner, TaylorFOWeightFilterPruner,
ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner]
dummy_input = torch.ones(1, 3, 224, 224)
for model_name in model_zoo:
for pruner in pruners:
Model = getattr(models, model_name)
cfg_generator = [generate_random_sparsity, generate_random_sparsity_v2]
for _generator in cfg_generator:
net = Model(pretrained=True, progress=False)
cfg_list = _generator(net)
print('\n\nModel:', model_name)
print('Pruner', pruner)
print('Config_list:', cfg_list)
                    # for the pruners that are based on activations, we need to
                    # feed enough data before calling the compress function.
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001,
momentum=0.9,
weight_decay=4e-5)
criterion = torch.nn.CrossEntropyLoss()
tmp_pruner = pruner(
net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
                    # train a single batch so that the pruner can collect the
                    # statistics
optimizer.zero_grad()
out = net(dummy_input)
batchsize = dummy_input.size(0)
loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
loss.backward()
optimizer.step()
tmp_pruner.compress()
tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
# if we want to use the same model, we should unwrap the pruner before the speedup
tmp_pruner._unwrap_model()
ms = ModelSpeedup(net, dummy_input, MASK_FILE)
ms.speedup_model()
if __name__ == '__main__':
main()