Speed up Masked Model
=====================

*This feature is in Beta version.*

Introduction
------------

Pruning algorithms usually use weight masks to simulate real pruning. Masks can be used
to check the performance of a model under a specific pruning (or sparsity) configuration, but
they bring no real speedup, because the masked weights keep their original shapes and are still
computed. Since speedup is the ultimate goal of model pruning, we provide a tool that converts
a model into a smaller one based on user-provided masks (the masks come from the pruning algorithms).
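
As a minimal illustration in plain PyTorch (not the NNI API), and assuming a mask is applied by element-wise multiplication, the sketch below masks three of the eight output channels of a ``Conv2d``. The masked channels output zeros, but the weight keeps its original shape, so the layer still performs the same amount of computation.

.. code-block:: python

   import torch
   import torch.nn as nn

   torch.manual_seed(0)
   conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
   x = torch.randn(1, 3, 16, 16)

   # Simulated channel pruning: keep output channels 0-4, mask channels 5-7.
   weight_mask = torch.ones_like(conv.weight)
   weight_mask[5:] = 0
   bias_mask = torch.ones_like(conv.bias)
   bias_mask[5:] = 0
   with torch.no_grad():
       conv.weight.mul_(weight_mask)
       conv.bias.mul_(bias_mask)

   out = conv(x)
   print(conv.weight.shape)  # torch.Size([8, 3, 3, 3]): no parameter was removed
   print(out.shape)          # torch.Size([1, 8, 16, 16]): masked channels are zeros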

There are two types of pruning. One is fine-grained pruning, which does not change the shapes of weights or of input/output tensors; a sparse kernel is required to speed up a fine-grained pruned layer. The other is coarse-grained pruning (e.g., pruning whole channels), which usually changes the shapes of weights and input/output tensors. Speeding up this kind of pruning needs no sparse kernel: the pruned layer is simply replaced with a smaller one. Since community support for sparse kernels is limited, we only support the speedup of coarse-grained pruning and leave fine-grained pruning to future work.
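
The difference between the two kinds of masks can be checked mechanically. The helper ``is_channel_coarse_grained`` below is hypothetical (it is not part of NNI): it tests whether a convolution weight mask of shape ``[out, in, kh, kw]`` removes whole output channels, in which case the layer can shrink, or only scattered weights, in which case its shape cannot change.

.. code-block:: python

   import torch

   def is_channel_coarse_grained(weight_mask):
       """True if every output channel is either fully kept or fully pruned."""
       per_channel = weight_mask.reshape(weight_mask.shape[0], -1)
       all_kept = (per_channel == 1).all(dim=1)
       all_pruned = (per_channel == 0).all(dim=1)
       return bool((all_kept | all_pruned).all())

   coarse = torch.ones(8, 3, 3, 3)
   coarse[5:] = 0          # whole filters removed: the layer can shrink to 5 filters
   fine = torch.ones(8, 3, 3, 3)
   fine[:, :, 0, 0] = 0    # scattered zeros: the shape cannot shrink
   print(is_channel_coarse_grained(coarse), is_channel_coarse_grained(fine))  # True False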

Design and Implementation
-------------------------

To speed up a model, the pruned layers should be replaced: with a smaller layer for a coarse-grained mask, or with a sparse kernel for a fine-grained mask. A coarse-grained mask usually changes the shapes of weights or input/output tensors, so we need shape inference to check whether other, unpruned layers must also be replaced because of the shape change. Therefore, our design has two main steps: first, run shape inference to find all the modules that should be replaced; second, replace those modules. The first step requires the topology (i.e., the connections) of the model; for PyTorch, we use ``jit.trace`` to obtain the model graph.
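
The following is a minimal sketch of the shape-inference step under a simplifying assumption: the model is a linear ``nn.Sequential``, so module order stands in for the connections that ``jit.trace`` would provide. The helper name is hypothetical and only a few module types are handled. It shows why pruning the output channels of one convolution forces otherwise unpruned modules to be replaced as well.

.. code-block:: python

   import torch.nn as nn

   def modules_to_replace(seq, pruned_idx):
       """Return (index, reason) pairs for modules of a linear nn.Sequential that
       must be replaced when the Conv2d at ``pruned_idx`` loses output channels."""
       affected = [(pruned_idx, 'fewer output channels (pruned)')]
       for idx in range(pruned_idx + 1, len(seq)):
           module = seq[idx]
           if isinstance(module, nn.BatchNorm2d):
               # Holds one affine parameter and running statistic per channel.
               affected.append((idx, 'fewer num_features'))
           elif isinstance(module, (nn.ReLU, nn.MaxPool2d)):
               # Element-wise or spatial ops work for any channel count.
               continue
           elif isinstance(module, nn.Conv2d):
               # Consumes the pruned channels; its own output shape is unchanged,
               # so propagation stops here.
               affected.append((idx, 'fewer input channels'))
               break
           else:
               raise NotImplementedError(f'no shape rule for {type(module).__name__}')
       return affected

   model = nn.Sequential(
       nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
       nn.Conv2d(16, 32, 3), nn.ReLU(),
   )
   for idx, reason in modules_to_replace(model, 0):
       print(idx, type(model[idx]).__name__, reason)

Here the unpruned ``BatchNorm2d`` at index 1 and ``Conv2d`` at index 3 must both be replaced, while the ``ReLU`` modules are left alone.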

For each module type, we need four functions: three for shape inference and one for module replacement. The three shape inference functions are: given the weight shape, infer the input/output shapes; given the input shape, infer the weight/output shapes; and given the output shape, infer the weight/input shapes. The module replacement function returns a newly created, smaller module.
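
For illustration only, the four functions could look as follows for ``nn.Conv2d`` (assuming ``groups=1``). The function names, and the representation of a "shape" as the indices of kept channels, are assumptions of this sketch rather than NNI's internal interface.

.. code-block:: python

   import torch
   import torch.nn as nn

   def conv_infer_from_weight(weight_mask):
       """Weight mask -> (kept input channels, kept output channels)."""
       out_keep = weight_mask.sum(dim=(1, 2, 3)).nonzero().flatten()
       in_keep = weight_mask.sum(dim=(0, 2, 3)).nonzero().flatten()
       return in_keep, out_keep

   def conv_infer_from_input(conv, in_keep):
       """Kept input channels -> (weight slice, kept output channels).
       Dropping inputs shrinks weight dim 1; all output channels remain."""
       out_keep = torch.arange(conv.out_channels)
       return (out_keep, in_keep), out_keep

   def conv_infer_from_output(conv, out_keep):
       """Kept output channels -> (weight slice, kept input channels).
       Dropping outputs shrinks weight dim 0; all input channels remain."""
       in_keep = torch.arange(conv.in_channels)
       return (out_keep, in_keep), in_keep

   def conv_replace(conv, in_keep, out_keep):
       """Build a smaller Conv2d that holds only the kept weight slices."""
       new = nn.Conv2d(len(in_keep), len(out_keep), conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, bias=conv.bias is not None,
                       padding_mode=conv.padding_mode)
       with torch.no_grad():
           new.weight.copy_(conv.weight[out_keep][:, in_keep])
           if conv.bias is not None:
               new.bias.copy_(conv.bias[out_keep])
       return new

   torch.manual_seed(0)
   conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
   weight_mask = torch.ones_like(conv.weight)
   weight_mask[5:] = 0                   # prune output channels 5-7

   in_keep, out_keep = conv_infer_from_weight(weight_mask)
   small = conv_replace(conv, in_keep, out_keep)
   x = torch.randn(2, 3, 16, 16)
   print(small)                          # a Conv2d with 3 input and 5 output channels

On the kept channels, ``small(x)`` matches ``conv(x)[:, out_keep]`` up to floating-point error.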

Usage
-----

.. code-block:: python

   import time

   from nni.compression.pytorch import ModelSpeedup

   # model: the pruned model you want to speed up
   # dummy_input: dummy input of the model, passed to `jit.trace`
   # masks_file: the mask file exported by the pruning algorithm
   # device: the device the model runs on, e.g. torch.device('cuda')
   dummy_input = dummy_input.to(device)
   m_speedup = ModelSpeedup(model, dummy_input, masks_file)
   m_speedup.speedup_model()  # replaces the pruned modules of `model` in place

   start = time.time()
   out = model(dummy_input)
   print('elapsed time: ', time.time() - start)

For complete examples, please refer to :githublink:`the code <examples/model_compress/pruning/speedup/model_speedup.py>`.

NOTE: The current implementation supports PyTorch 1.3.1 or newer.

Limitations
-----------

Since every module type requires four functions for shape inference and module replacement, supporting a module takes considerable work, so we have only implemented the modules required by the examples. If you want to speed up a model that the current implementation does not support, contributions are welcome.

For PyTorch, we can only replace modules; if functions called in ``forward`` need to be replaced, the current implementation does not work. One workaround is to wrap the function in a PyTorch module.
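
A sketch of this workaround in plain PyTorch: the functional calls ``F.relu`` and ``torch.flatten`` in ``forward`` are replaced by the equivalent modules ``nn.ReLU`` and ``nn.Flatten``, which then appear as submodules and can be replaced. The two class names are made up for this example; both networks have the same parameters and compute the same output.

.. code-block:: python

   import torch
   import torch.nn as nn
   import torch.nn.functional as F

   class NetWithFunctional(nn.Module):
       def __init__(self):
           super().__init__()
           self.conv = nn.Conv2d(3, 8, 3)
           self.fc = nn.Linear(8 * 6 * 6, 10)

       def forward(self, x):
           x = F.relu(self.conv(x))            # functional calls: not modules
           return self.fc(torch.flatten(x, 1))

   class NetWithModules(nn.Module):
       def __init__(self):
           super().__init__()
           self.conv = nn.Conv2d(3, 8, 3)
           self.relu = nn.ReLU()               # same ops, now as submodules
           self.flatten = nn.Flatten(1)
           self.fc = nn.Linear(8 * 6 * 6, 10)

       def forward(self, x):
           return self.fc(self.flatten(self.relu(self.conv(x))))

   before, after = NetWithFunctional(), NetWithModules()
   after.load_state_dict(before.state_dict())  # ReLU and Flatten add no parameters
   x = torch.randn(2, 3, 8, 8)
   print(torch.allclose(before(x), after(x)))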

Speedup Results of Examples
---------------------------

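The code of these experiments can be found :githublink:`here <examples/model_compress/pruning/speedup/model_speedup.py>`.
In each table, *Times* is the number of forward passes that were timed, and *Mask Latency* and *Speedup Latency* are the total elapsed times (in seconds) of those passes for the masked model and for the sped-up model, respectively.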
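
For reference, here is one way to time a given number of forward passes, similar in spirit to the *Times* column of the tables. This is a sketch, not the example script itself: the helper name ``measure_latency``, the untimed warm-up pass, and the ``torch.cuda.synchronize()`` calls (needed because CUDA kernels run asynchronously) are choices of this sketch.

.. code-block:: python

   import time

   import torch
   import torch.nn as nn

   def measure_latency(model, dummy_input, times=32):
       """Return the total wall-clock seconds spent on ``times`` forward passes."""
       model.eval()
       with torch.no_grad():
           model(dummy_input)                # warm-up pass, not timed
           if dummy_input.is_cuda:
               torch.cuda.synchronize()
           start = time.perf_counter()
           for _ in range(times):
               model(dummy_input)
           if dummy_input.is_cuda:
               torch.cuda.synchronize()      # wait for queued GPU kernels
           return time.perf_counter() - start

   model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
   for times in (1, 2, 4):
       print(times, measure_latency(model, torch.randn(64, 3, 32, 32), times))

To compare the masked and sped-up versions fairly, call the helper on the same input and device before and after ``speedup_model()``.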

Slim pruner example
^^^^^^^^^^^^^^^^^^^

Measured on one V100 GPU with input tensor ``torch.randn(64, 3, 32, 32)``.

.. list-table::
   :header-rows: 1
   :widths: auto

   * - Times
     - Mask Latency
     - Speedup Latency
   * - 1
     - 0.01197
     - 0.005107
   * - 2
     - 0.02019
     - 0.008769
   * - 4
     - 0.02733
     - 0.014809
   * - 8
     - 0.04310
     - 0.027441
   * - 16
     - 0.07731
     - 0.05008
   * - 32
     - 0.14464
     - 0.10027


FPGM pruner example
^^^^^^^^^^^^^^^^^^^

Measured on CPU with input tensor ``torch.randn(64, 1, 28, 28)``.
The variance of these measurements is too large for them to be reliable.

.. list-table::
   :header-rows: 1
   :widths: auto

   * - Times
     - Mask Latency
     - Speedup Latency
   * - 1
     - 0.01383
     - 0.01839
   * - 2
     - 0.01167
     - 0.003558
   * - 4
     - 0.01636
     - 0.01088
   * - 40
     - 0.14412
     - 0.08268
   * - 40
     - 1.29385
     - 0.14408
   * - 40
     - 0.41035
     - 0.46162
   * - 400
     - 6.29020
     - 5.82143


L1Filter pruner example
^^^^^^^^^^^^^^^^^^^^^^^

Measured on one V100 GPU with input tensor ``torch.randn(64, 3, 32, 32)``.

.. list-table::
   :header-rows: 1
   :widths: auto

   * - Times
     - Mask Latency
     - Speedup Latency
   * - 1
     - 0.01026
     - 0.003677
   * - 2
     - 0.01657
     - 0.008161
   * - 4
     - 0.02458
     - 0.020018
   * - 8
     - 0.03498
     - 0.025504
   * - 16
     - 0.06757
     - 0.047523
   * - 32
     - 0.10487
     - 0.086442


APoZ pruner example
^^^^^^^^^^^^^^^^^^^

Measured on one V100 GPU with input tensor ``torch.randn(64, 3, 32, 32)``.

.. list-table::
   :header-rows: 1
   :widths: auto

   * - Times
     - Mask Latency
     - Speedup Latency
   * - 1
     - 0.01389
     - 0.004208
   * - 2
     - 0.01628
     - 0.008310
   * - 4
     - 0.02521
     - 0.014008
   * - 8
     - 0.03386
     - 0.023923
   * - 16
     - 0.06042
     - 0.046183
   * - 32
     - 0.12421
     - 0.087113


SimulatedAnnealing pruner example
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In this experiment, we use the SimulatedAnnealing pruner to prune ResNet-18 on the CIFAR-10 dataset.
We measure the latency and accuracy of the pruned model under different sparsity ratios, as shown in the following figure.
The latency is measured on one V100 GPU with input tensor ``torch.randn(128, 3, 32, 32)``.


.. image:: ../../img/SA_latency_accuracy.png


User configuration for ModelSpeedup
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**PyTorch**

..  autoclass:: nni.compression.pytorch.ModelSpeedup