Unverified Commit 82dea355 authored by QuanluZhang, committed by GitHub

support lottery ticket hypothesis (#1685)

parent e240b895
Lottery Ticket Hypothesis on NNI
===
## Introduction
The paper [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) is mainly a measurement and analysis paper, and it delivers very interesting insights. To support it on NNI, we mainly implement the training approach for finding *winning tickets*.
In this paper, the authors use the following process to prune a model, called *iterative pruning*:
>1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and pruning runs for n rounds, each round prunes 1-(1-P)^(1/n) of the weights that survived the previous round.
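A quick sanity check of this per-round rate (a minimal sketch; P and n below are just illustrative values):

```python
# Each round prunes 1 - (1 - P)**(1/n) of the weights surviving so far.
P, n = 0.8, 5
per_round = 1 - (1 - P) ** (1 / n)
print('per-round prune rate: {:.4f}'.format(per_round))           # ~0.2752

# After n rounds the surviving fraction is ((1 - P)**(1/n))**n = 1 - P,
# so the final sparsity is exactly P.
print('final sparsity: {:.4f}'.format(1 - (1 - per_round) ** n))  # 0.8000
```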
## Reproduce Results
We try to reproduce the experiment result of the fully connected network on MNIST, using the same configuration as in the paper. The code can be found [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/lottery_torch_mnist_fc.py). In this experiment, we prune 10 times; after each pruning, we train the pruned model for 50 epochs.
![](../../img/lottery_ticket_mnist_fc.png)
The above figure shows the result of the fully connected network. `round0-sparsity-0.0` is the performance without pruning. Consistent with the paper, pruning to around 80% sparsity obtains performance similar to the unpruned network and converges a little faster. Pruning too much, e.g., beyond 94%, lowers accuracy and slows convergence slightly. A small difference from the paper is that the trends reported there are somewhat clearer.
@@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Lottery Ticket Pruner](./Pruner.md#lottery-ticket-hypothesis) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)|
| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
......
@@ -92,6 +92,47 @@ You can view example for more information
***
## Lottery Ticket Hypothesis
[The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635), by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis, and articulates the *lottery ticket hypothesis*: dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*) that -- when trained in isolation -- reach test accuracy comparable to the original network in a similar number of iterations.
In this paper, the authors use the following process to prune a model, called *iterative pruning*:
>1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and pruning runs for n rounds, each round prunes 1-(1-P)^(1/n) of the weights that survived the previous round.
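The cumulative sparsity therefore grows geometrically over the rounds; the sketch below mirrors the pruner's internal schedule (illustrative values, not output from NNI):

```python
# Sparsity reached after each prune iteration, for final sparsity P in n rounds:
# after round i the kept fraction is ((1 - P)**(1/n))**i.
P, n = 0.8, 5
keep_once = (1 - P) ** (1 / n)
for i in range(n + 1):
    print('round {}: sparsity {:.4f}'.format(i, max(1 - keep_once ** i, 0)))
```

Round 0 corresponds to training the full network before any weight is removed, which is why `get_prune_iterations` returns one more iteration than the configured `prune_iterations`.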
### Usage
PyTorch code
```python
from nni.compression.torch import LotteryTicketPruner
config_list = [{
    'prune_iterations': 5,
    'sparsity': 0.8,
    'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, config_list, optimizer)
pruner.compress()
for _ in pruner.get_prune_iterations():
    pruner.prune_iteration_start()
    for epoch in range(epoch_num):
        ...
```
The above configuration specifies 5 rounds of iterative pruning. Because all 5 rounds are executed within the same run, LotteryTicketPruner needs `model` and `optimizer` (**note: also pass `lr_scheduler` if one is used**) so that it can reset their states each time a new prune iteration starts. Use `get_prune_iterations` to get the pruning iterations, and invoke `prune_iteration_start` at the beginning of each iteration. `epoch_num` should be large enough for the model to converge, because the hypothesis is that the performance (accuracy) obtained in later rounds with high sparsity is comparable to that obtained in the first round. Simple reproduction results can be found [here](./LotteryTicketHypothesis.md).
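If a learning rate scheduler is used, it can be passed in the same way so that its state is restored together with the model and optimizer at each round (a minimal sketch; the `StepLR` settings below are arbitrary placeholders):

```python
import torch
# Hand the scheduler to the pruner so that, like the optimizer state,
# its initial state can be restored when a new prune iteration starts.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
pruner = LotteryTicketPruner(model, config_list, optimizer, lr_scheduler=scheduler)
```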
*Tensorflow version will be supported later.*
#### User configuration for LotteryTicketPruner
* **prune_iterations:** The number of rounds of iterative pruning.
* **sparsity:** The final sparsity when the compression is done.
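When all rounds finish, the pruned model and masks can be exported with the pruner's `export_model` API, as in the example script (file names here are arbitrary):

```python
# Saves the model's state_dict (with masks applied to the weights) and
# the mask dict to the given paths.
pruner.export_model(model_path='model.pth', mask_path='mask.pth')
```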
***
## FPGM Pruner
FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf)
......
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import LotteryTicketPruner


class fc1(nn.Module):

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def train(model, train_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()


def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


if __name__ == '__main__':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=True)

    model = fc1().to("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()

    configure_list = [{
        'prune_iterations': 10,
        'sparsity': 0.96,
        'op_types': ['default']
    }]
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()

    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()
        loss = 0
        accuracy = 0
        for epoch in range(50):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
from .lottery_ticket import LotteryTicketPruner
@@ -13,7 +13,6 @@ class LayerInfo:
        self._forward = None

class Compressor:
    """
    Abstract base PyTorch compressor
@@ -37,7 +36,6 @@ class Compressor:
    def detect_modules_to_compress(self):
        """
        detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
        The model will be instrumented and the user should never edit it after calling this method.
        """
        if self.modules_to_compress is None:
@@ -49,7 +47,6 @@ class Compressor:
                    self.modules_to_compress.append((layer, config))
        return self.modules_to_compress

    def compress(self):
        """
        Compress the model with algorithm implemented by subclass.
@@ -218,6 +215,8 @@ class Pruner(Compressor):
        input_shape : list or tuple
            input shape to onnx model
        """
        if self.detect_modules_to_compress() and not self.mask_dict:
            _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
        assert model_path is not None, 'model_path must be specified'
        for name, m in self.bound_model.named_modules():
            if name == "":
@@ -227,25 +226,20 @@ class Pruner(Compressor):
                mask_sum = mask.sum().item()
                mask_num = mask.numel()
                _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
                print('Layer: %s Sparsity: %.2f' % (name, 1 - mask_sum / mask_num))
                m.weight.data = m.weight.data.mul(mask)
            else:
                _logger.info('Layer: %s NOT compressed', name)
                print('Layer: %s NOT compressed' % name)
        torch.save(self.bound_model.state_dict(), model_path)
        _logger.info('Model state_dict saved to %s', model_path)
        print('Model state_dict saved to %s' % model_path)
        if mask_path is not None:
            torch.save(self.mask_dict, mask_path)
            _logger.info('Mask dict saved to %s', mask_path)
            print('Mask dict saved to %s' % mask_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # input info needed
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data, onnx_path)
            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
            print('Model in onnx with input shape %s saved to %s' % (input_data.shape, onnx_path))

class Quantizer(Compressor):
......
import copy
import logging
import torch
from .compressor import Pruner

_logger = logging.getLogger(__name__)


class LotteryTicketPruner(Pruner):
    """
    This is a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
    following the NNI model compression interface.

    1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
    2. Train the network for j iterations, arriving at parameters theta_j.
    3. Prune p% of the parameters in theta_j, creating a mask m.
    4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
    5. Repeat steps 2, 3, and 4.
    """

    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
        """
        Parameters
        ----------
        model : pytorch model
            The model to be pruned
        config_list : list
            Supported keys:
                - prune_iterations : The number of rounds of iterative pruning.
                - sparsity : The final sparsity when the compression is done.
        optimizer : pytorch optimizer
            The optimizer for the model
        lr_scheduler : pytorch lr scheduler
            The lr scheduler for the model, if used
        reset_weights : bool
            Whether to reset the weights and optimizer at the beginning of each round.
        """
        super().__init__(model, config_list)
        self.curr_prune_iteration = None
        self.prune_iterations = self._validate_config(config_list)

        # save init weights and optimizer
        self.reset_weights = reset_weights
        if self.reset_weights:
            self._model = model
            self._optimizer = optimizer
            self._model_state = copy.deepcopy(model.state_dict())
            self._optimizer_state = copy.deepcopy(optimizer.state_dict())
            self._lr_scheduler = lr_scheduler
            if lr_scheduler is not None:
                self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())

    def _validate_config(self, config_list):
        prune_iterations = None
        for config in config_list:
            assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
            assert 'sparsity' in config, 'sparsity must exist in your config'
            if prune_iterations is not None:
                assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
            prune_iterations = config['prune_iterations']
        return prune_iterations

    def _print_masks(self, print_mask=False):
        torch.set_printoptions(threshold=1000)
        for op_name in self.mask_dict.keys():
            mask = self.mask_dict[op_name]
            print('op name: ', op_name)
            if print_mask:
                print('mask: ', mask)
            # calculate current sparsity
            mask_num = mask.sum().item()
            mask_size = mask.numel()
            print('sparsity: ', 1 - mask_num / mask_size)
        torch.set_printoptions(profile='default')

    def _calc_sparsity(self, sparsity):
        keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
        curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
        return max(1 - curr_keep_ratio, 0)

    def _calc_mask(self, weight, sparsity, op_name):
        if self.curr_prune_iteration == 0:
            mask = torch.ones(weight.shape).type_as(weight)
        else:
            curr_sparsity = self._calc_sparsity(sparsity)
            assert self.mask_dict.get(op_name) is not None
            curr_mask = self.mask_dict.get(op_name)
            w_abs = weight.abs() * curr_mask
            k = int(w_abs.numel() * curr_sparsity)
            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
            mask = torch.gt(w_abs, threshold).type_as(weight)
        return mask

    def calc_mask(self, layer, config):
        """
        Generate mask for the given ``weight``.

        Parameters
        ----------
        layer : LayerInfo
            The layer to be pruned
        config : dict
            Pruning configurations for this weight

        Returns
        -------
        tensor
            The mask for this weight
        """
        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
        mask = self.mask_dict[layer.name]
        return mask

    def get_prune_iterations(self):
        """
        Return the range for iterations.
        In the first prune iteration, masks are all ones; thus, one more iteration is added.

        Returns
        -------
        list
            A list of pruning iterations
        """
        return range(self.prune_iterations + 1)

    def prune_iteration_start(self):
        """
        Control the pruning procedure on updated epoch number.
        Should be called at the beginning of the epoch.
        """
        if self.curr_prune_iteration is None:
            self.curr_prune_iteration = 0
        else:
            self.curr_prune_iteration += 1
        assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceeded the configured prune_iterations'

        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
            sparsity = config.get('sparsity')
            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
            self.mask_dict.update({layer.name: mask})
        self._print_masks()

        # reinit weights back to original after new masks are generated
        if self.reset_weights:
            self._model.load_state_dict(self._model_state)
            self._optimizer.load_state_dict(self._optimizer_state)
            if self._lr_scheduler is not None:
                self._lr_scheduler.load_state_dict(self._scheduler_state)