# Customize A New Compression Algorithm

```eval_rst
.. contents::
```

To simplify writing a new compression algorithm, we designed programming interfaces that are simple yet flexible. There are separate interfaces for pruning and for quantization. Below, we first demonstrate how to customize a new pruning algorithm and then how to customize a new quantization algorithm.

## Customize a new pruning algorithm

To better demonstrate how to customize a new pruning algorithm, it is necessary for users to first understand the framework for supporting various pruning algorithms in NNI.

### Framework overview for pruning algorithms

The following example shows how to use a pruner:

```python
import torch
from nni.compression.torch import LevelPruner

# load a pretrained model or train a model before using a pruner

configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# the model is now ready for pruning; start fine-tuning the model,
# and it will be pruned automatically during training
```

A pruner receives `model`, `config_list` and `optimizer` as arguments. It prunes the model according to `config_list` during the training loop by adding a hook on `optimizer.step()`.

From the implementation perspective, a pruner consists of a `weight masker` instance and multiple `module wrapper` instances.

#### Weight masker

A `weight masker` is the implementation of a pruning algorithm; it prunes a specified layer, wrapped by a `module wrapper`, with the specified sparsity.

#### Module wrapper

A `module wrapper` is a module that contains:

1. the original module
2. some buffers used by `calc_mask`
3. a new forward method that applies the masks before running the original forward method.

The reasons for using a `module wrapper` are:

1. some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in the `module wrapper` so that the original modules are not contaminated.
2. a new `forward` method is needed to apply the masks to the weights before calling the real `forward` method.

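As a rough conceptual sketch (simplified, not the actual NNI implementation), a `module wrapper` can be pictured like this; the class name and buffer name below are illustrative only:

```python
import torch

class ConceptualModuleWrapper(torch.nn.Module):
    """Illustrative sketch only; the real wrapper lives inside NNI's compressor."""
    def __init__(self, module):
        super().__init__()
        self.module = module  # the original module
        # buffer used by `calc_mask`; registering it here keeps the original module clean
        self.register_buffer("weight_mask", torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask to the weight before calling the original forward
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(*inputs)
```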
#### Pruner

A `pruner` is responsible for:

1. Manage and verify the `config_list`.
2. Use `module wrapper` to wrap the model layers and add a hook on `optimizer.step`.
3. Use the `weight masker` to calculate the masks of layers while pruning.
4. Export the pruned model weights and masks (see the example below).

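For example, after fine-tuning, the pruner can export the pruned weights and the masks; the `export_model` call below follows the NNI pruner API, and the file names are placeholders:

```python
# export the pruned model weights and the corresponding masks
# (file names are placeholders)
pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth')
```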
### Implement a new pruning algorithm

Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.

An implementation of `weight masker` may look like this:

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting statistics,
        # if your algorithm needs them to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the mask based on the wrapper's weight and the given sparsity,
        # plus anything else your algorithm needs
        # mask = ...
        return {'weight_mask': mask}
```

You can refer to the [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations provided by NNI to implement your own weight masker.

A basic `pruner` looks like this:

```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        # construct a weight masker instance
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # Already pruned, do not prune again, as this is a one-shot pruner
            return None
        else:
            # call your masker to actually calculate the mask for this layer
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            return masks
```

Refer to the [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations provided by NNI to implement your own pruner class.

### Set wrapper attribute

Sometimes `calc_mask` needs to save some state data, so users can use the `set_wrappers_attribute` API to register attributes, just as buffers are registered in PyTorch modules. These buffers are registered on the `module wrapper` and can be accessed through the `module wrapper`.
In the above example, we use `set_wrappers_attribute` to set a buffer `if_calculated`, which is used as a flag indicating whether the mask of a layer has already been calculated.

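As a small sketch of this pattern (the class names and the attribute name `num_mask_updates` below are purely illustrative, not part of the NNI API):

```python
class CountingPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        # register an illustrative counter buffer on every module wrapper
        self.set_wrappers_attribute("num_mask_updates", 0)
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        # the registered attribute is accessible directly on each wrapper
        wrapper.num_mask_updates += 1
        return self.masker.calc_mask(sparsity=wrapper.config['sparsity'],
                                     wrapper=wrapper, wrapper_idx=wrapper_idx)
```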
### Collect data during forward

Sometimes users want to collect some data during the modules' forward method, for example, the mean value of the activations. This can be done by adding a customized collector to the module.

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # Set attribute `collected_activation` for all wrappers to store
        # the activations of each layer
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.activation = torch.nn.functional.relu

        def collector(wrapper, input_, output):
            # The collected activations can be accessed via each wrapper's
            # `collected_activation` attribute
            wrapper.collected_activation.append(self.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```

The collector function will be called each time the `forward` method runs.

Users can also remove this collector as follows:

```python
# Save the collector identifier
collector_id = self.pruner.add_activation_collector(collector)

# When the collector is no longer needed, it can be removed using
# the saved collector identifier
self.pruner.remove_activation_collector(collector_id)
```

### Multi-GPU support

During multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an in-place update is needed to ensure the update is effective.
Since `calc_mask` is called in the `optimizer.step` method, which runs after the `forward` method and only on one GPU, it supports multi-GPU training naturally.
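
A minimal sketch of the in-place pattern mentioned above; the class and the `forward_count` buffer are hypothetical and only illustrate the difference between in-place and out-of-place updates:

```python
import torch

class CountingWrapper(torch.nn.Module):
    """Illustrative only: shows the in-place buffer update pattern."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.register_buffer("forward_count", torch.zeros(1))

    def forward(self, *inputs):
        # in-place update: modifies the existing tensor's storage, so it stays effective
        self.forward_count.add_(1)
        # NOT: self.forward_count = self.forward_count + 1  (out-of-place; the update may be lost)
        return self.module(*inputs)
```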

***

## Customize a new quantization algorithm

To write a new quantization algorithm, you can write a class that inherits `nni.compression.torch.Quantizer`. Then, override the member functions with the logic of your algorithm. The key member function to override is `quantize_weight`, which directly returns the quantized weights rather than a mask, because for quantization the quantized weights cannot be obtained by applying a mask.

```python
from nni.compression.torch import Quantizer

class YourQuantizer(Quantizer):
    def __init__(self, model, config_list):
        """
        We suggest using the NNI-defined spec for config
        """
        super().__init__(model, config_list)

    def quantize_weight(self, weight, config, **kwargs):
        """
        Quantizers should overload this method to quantize weight tensors.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        weight : Tensor
            weight that needs to be quantized
        config : dict
            the configuration for weight quantization
        """

        # Put your code to generate `new_weight` here

        return new_weight
    
    def quantize_output(self, output, config, **kwargs):
        """
        Quantizers should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        output : Tensor
            output that needs to be quantized
        config : dict
            the configuration for output quantization
        """

        # Put your code to generate `new_output` here

        return new_output

    def quantize_input(self, *inputs, config, **kwargs):
        """
        Quantizers should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        inputs : Tensor
            inputs that need to be quantized
        config : dict
            the configuration for inputs quantization
        """

        # Put your code to generate `new_input` here

        return new_input

    def update_epoch(self, epoch_num):
        pass

    def step(self):
        """
        Can do some processing based on the model or the weights bound
        in the function `bind_model`
        """
        pass
```
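
A brief usage sketch of such a quantizer is given below; the config keys (`quant_types`, `quant_bits`, `op_types`) follow NNI's quantization config specification, and `model` is assumed to be a prepared PyTorch model:

```python
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d', 'Linear']
}]

quantizer = YourQuantizer(model, config_list)
quantizer.compress()
```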

### Customize backward function

Sometimes it is necessary for a quantization operation to have a customized backward function, such as the [Straight-Through Estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste). Users can customize a backward function as follows:

```python
import torch
from nni.compression.torch.compressor import Quantizer, QuantGrad, QuantType

class ClipGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        """
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is the Straight-Through Estimator.
        Parameters
        ----------
        tensor : Tensor
            input of quantization operation
        grad_output : Tensor
            gradient of the output of quantization operation
        quant_type : QuantType
            the type of quantization; it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, or `QuantType.QUANT_OUTPUT`,
            and you can define different behavior for each type.
        Returns
        -------
        tensor
            gradient of the input of quantization operation
        """

        # for output quantization, set the gradient to zero where the absolute value of the tensor is larger than 1
        if quant_type == QuantType.QUANT_OUTPUT:
            grad_output[torch.abs(tensor) > 1] = 0
        return grad_output


class YourQuantizer(Quantizer):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        # set your customized backward function to overwrite default backward function
        self.quant_grad = ClipGrad

```

If you do not customize `QuantGrad`, the default backward function is the Straight-Through Estimator.

_Coming Soon_ ...