"""
Pruning Quickstart
==================

Model pruning is a technique to reduce the model size and computation by reducing model weight size or intermediate state size.
There are three common practices for pruning a DNN model:

#. Pre-training a model -> Pruning the model -> Fine-tuning the pruned model
#. Pruning a model during training (i.e., pruning-aware training) -> Fine-tuning the pruned model
#. Pruning a model -> Training the pruned model from scratch

NNI supports all of the above pruning practices by working on the key pruning stage.
Follow this tutorial for a quick look at how to use NNI to prune a model in a common practice.
"""

# %%
# Preparation
# -----------
#
# In this tutorial, we use a simple model pre-trained on the MNIST dataset.
# If you are familiar with defining a model and training in PyTorch, you can skip directly to `Pruning Model`_.

import torch
import torch.nn.functional as F
from torch.optim import SGD

from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device

# define the model
model = TorchModel().to(device)

# show the model structure; note that the pruner will wrap the model layers later on.
print(model)
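
# %%
# The ``TorchModel`` used above comes from ``scripts.compression_mnist_model`` and is not shown
# in this tutorial. For reference, a LeNet-style sketch of such a model is given below. The class
# name, layer sizes and forward pass are illustrative assumptions rather than NNI's exact helper;
# all that matters for the rest of the tutorial is that the model contains ``Conv2d``/``Linear``
# layers and a final layer named ``fc3``.

class SketchTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        # two conv + max-pool stages, then three fully connected layers
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x), dim=1)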

# %%

# define the optimizer and criterion for pre-training

optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
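
# %%
# The ``trainer`` and ``evaluator`` helpers also come from ``scripts.compression_mnist_model`` and
# are not shown here. A minimal sketch of what such helpers could look like is given below; the
# dataloaders, batch sizes and normalization constants are illustrative assumptions, not NNI's
# exact implementation.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

def sketch_trainer(model, optimizer, criterion):
    # one epoch of plain supervised training on MNIST
    loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=_transform),
                        batch_size=128, shuffle=True)
    model.train()
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        criterion(model(data), target).backward()
        optimizer.step()

def sketch_evaluator(model):
    # report top-1 accuracy on the MNIST test set
    loader = DataLoader(datasets.MNIST('data', train=False, download=True, transform=_transform),
                        batch_size=1000)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            correct += (model(data).argmax(dim=1) == target).sum().item()
    print('Accuracy: {:.2f}%'.format(100. * correct / len(loader.dataset)))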

# %%
# Pruning Model
# -------------
#
# Using L1NormPruner to prune the model and generate the masks.
# Usually, a pruner requires the original model and a ``config_list`` as its inputs.
# For details on how to write a ``config_list``, please refer to :doc:`compression config specification <../compression/compression_config_list>`.
#
# The following `config_list` means all layers whose type is `Linear` or `Conv2d` will be pruned,
# except the layer named `fc3`, because `fc3` is marked as `exclude`.
# The final sparsity ratio for each pruned layer is 50%.

config_list = [{
    'sparsity_per_layer': 0.5,
    'op_types': ['Linear', 'Conv2d']
}, {
    'exclude': True,
    'op_names': ['fc3']
}]
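
# %%
# As a quick, illustrative sanity check (not part of the original tutorial), we can list the
# layers that match the first config entry and mark the one excluded by name:

for layer_name, module in model.named_modules():
    if type(module).__name__ in ('Conv2d', 'Linear'):
        print(layer_name, type(module).__name__,
              'excluded' if layer_name == 'fc3' else 'sparsity 0.5')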

# %%
# Pruners usually require `model` and `config_list` as input arguments.

from nni.compression.pytorch.pruning import L1NormPruner
pruner = L1NormPruner(model, config_list)

# show the wrapped model structure; `PrunerModuleWrapper` has wrapped the layers configured in the `config_list`.
print(model)

# %%

# compress the model and generate the masks
_, masks = pruner.compress()
# show the masks sparsity
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))
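
# %%
# As an illustrative addition (not in the original tutorial), the masks can also be aggregated
# to get the overall ratio of weights kept across all masked layers:

total_kept = sum(mask['weight'].sum().item() for mask in masks.values())
total_weights = sum(mask['weight'].numel() for mask in masks.values())
print('overall kept weight ratio: {:.2f}'.format(total_kept / total_weights))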

# %%
# Speed up the original model with the masks; note that `ModelSpeedup` requires an unwrapped model.
# The model becomes smaller after speedup,
# and reaches a higher sparsity ratio because `ModelSpeedup` will propagate the masks across layers.

# the model needs to be unwrapped before speedup if it has been wrapped by a pruner
pruner._unwrap_model()

# speedup the model
from nni.compression.pytorch.speedup import ModelSpeedup

ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()

# %%
# the model becomes physically smaller after speedup
print(model)
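
# %%
# A simple way to quantify the effect of speedup (an illustrative addition) is to count the
# parameters left in the compacted model:

print('number of parameters after speedup:', sum(p.numel() for p in model.parameters()))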

# %%
# Fine-tuning Compacted Model
# ---------------------------
# Note that if the model has been sped up, you need to re-initialize a new optimizer for fine-tuning,
# because speedup replaces the masked large layers with smaller dense ones.

optimizer = SGD(model.parameters(), 1e-2)
for epoch in range(3):
    trainer(model, optimizer, criterion)
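
# %%
# Finally (not part of the original tutorial), evaluate the fine-tuned compact model and save its
# weights with the standard PyTorch API. Note that reloading the pruned weights later also needs
# the compacted model definition (or re-running pruning and speedup), because the layer shapes
# have changed.

evaluator(model)
# the output path below is arbitrary
torch.save(model.state_dict(), 'pruned_mnist_model.pth')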