# Framework overview of model compression

```eval_rst
.. contents::
```

The picture below shows an overview of the components of the model compression framework.

![](../../img/compressor_framework.jpg)

There are three major components/classes in the NNI model compression framework: `Compressor`, `Pruner` and `Quantizer`. Let's look at them in detail one by one.

## Compressor

`Compressor` is the base class for pruners and quantizers. It provides a unified interface so that pruners and quantizers can be used by end users in the same way. For example, to use a pruner:

```python
import torch
from nni.algorithms.compression.pytorch.pruning import LevelPruner

# load a pretrained model or train a model before using a pruner

configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# the model is now ready for pruning; start fine-tuning it,
# and it will be pruned during training automatically
```

To use a quantizer:
```python
import torch
from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer

configure_list = [{
    'quant_types': ['weight'],
    'quant_bits': {
        'weight': 8,
    },
    'op_types': ['Conv2d', 'Linear']
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
quantizer = DoReFaQuantizer(model, configure_list, optimizer)
quantizer.compress()
```
View [example code](https://github.com/microsoft/nni/tree/v1.9/examples/model_compress) for more information.

The `Compressor` class provides some utility methods for subclasses and users:

### Set wrapper attribute

Sometimes `calc_mask` needs to save some state data, so users can use the `set_wrappers_attribute` API to register an attribute, just like buffers are registered in PyTorch modules. These buffers are registered on the `module wrapper`, and users can access them through the `module wrapper`.
For example, `set_wrappers_attribute` can be used to set a buffer `if_calculated`, which serves as a flag indicating whether the mask of a layer has already been calculated.
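
A minimal sketch of the idea, assuming a hypothetical one-shot pruner subclass (the import path of the `Pruner` base class may differ across NNI versions):

```python
import torch
# Assumed base-class location (NNI 2.x); adjust to your NNI version
from nni.compression.pytorch.compressor import Pruner

class OneShotPrunerSketch(Pruner):
    """Hypothetical pruner that computes each layer's mask only once."""
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # Register a boolean buffer `if_calculated` on every module wrapper
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper, **kwargs):
        if wrapper.if_calculated:
            # The mask for this layer was already computed, nothing to do
            return None
        wrapper.if_calculated = True
        # Placeholder: keep all weights (a real algorithm would zero out
        # entries according to wrapper.config['sparsity'])
        return {'weight_mask': torch.ones_like(wrapper.module.weight.data)}
```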

### Collect data during forward

Sometimes users want to collect some data during the modules' `forward` method, for example, the mean value of the activations. This can be done by adding a customized collector to a module.

```python
import torch

class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # Set attribute `collected_activation` for all wrappers to store
        # activations for each layer
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.activation = torch.nn.functional.relu

        def collector(wrapper, input_, output):
            # The collected activation can be accessed via each wrapper's collected_activation
            # attribute
            wrapper.collected_activation.append(self.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```

The collector function will be called each time the `forward` method runs.

Users can also remove this collector like this:

```python
# Save the collector identifier
collector_id = self.pruner.add_activation_collector(collector)

# When the collector is no longer needed, it can be removed using
# the saved collector identifier
self.pruner.remove_activation_collector(collector_id)
```

***

## Pruner

A pruner receives `model`, `config_list` and `optimizer` as arguments. It prunes the model per the `config_list` during the training loop by adding a hook on `optimizer.step()`.

The `Pruner` class is a subclass of `Compressor`, so it contains everything in the `Compressor` class plus some additional components used only for pruning:

### Weight masker

A `weight masker` is the implementation of a pruning algorithm; it can prune a specified layer, wrapped by a `module wrapper`, to a specified sparsity.
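
For illustration, a magnitude-based masker could look roughly like the sketch below (the exact `calc_mask` signature of NNI's `WeightMasker` base class may vary between versions):

```python
import torch

class MagnitudeMaskerSketch(WeightMasker):
    """Hypothetical masker: zero out the weights with the smallest magnitude."""
    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        weight = wrapper.module.weight.data
        num_prune = int(weight.numel() * sparsity)
        if num_prune == 0:
            return {'weight_mask': torch.ones_like(weight)}
        # Threshold at the `num_prune`-th smallest absolute value
        threshold = torch.topk(weight.abs().view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(weight.abs(), threshold).type_as(weight)
        return {'weight_mask': mask}
```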

### Pruning module wrapper

A `pruning module wrapper` is a module containing:

1. the original module
2. some buffers used by `calc_mask`
3. a new `forward` method that applies masks before running the original `forward` method.

The reasons to use a `module wrapper` are:

1. some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in the `module wrapper` so that the original modules are not contaminated.
2. a new `forward` method is needed to apply masks to the weights before calling the real `forward` method; a simplified sketch of such a wrapper is shown below.
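
Conceptually, the wrapper behaves like the following simplified sketch (not NNI's actual implementation; names are illustrative):

```python
import torch

class PrunerModuleWrapperSketch(torch.nn.Module):
    """Simplified sketch of a pruning module wrapper."""
    def __init__(self, module):
        super().__init__()
        self.module = module  # the original module
        # buffer used by `calc_mask`, kept on the wrapper so the original
        # module is not contaminated
        self.register_buffer("weight_mask", torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask before running the original forward
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(*inputs)
```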

### Pruning hook

A pruning hook is installed on the pruner when the pruner is constructed; it is used to call the pruner's `calc_mask` method when `optimizer.step()` is invoked.
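
A rough sketch of this mechanism, for illustration only (NNI installs and manages such a hook internally; the sketch assumes the pruner exposes its wrappers through `get_modules_wrapper()`):

```python
def install_pruning_hook_sketch(pruner, optimizer):
    """Hypothetical helper: recompute masks after every optimizer step."""
    original_step = optimizer.step

    def patched_step(*args, **kwargs):
        result = original_step(*args, **kwargs)
        # After the weights have been updated, recalculate the mask of
        # every wrapped layer
        for wrapper in pruner.get_modules_wrapper():
            pruner.calc_mask(wrapper)
        return result

    optimizer.step = patched_step
```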

***

## Quantizer

The `Quantizer` class is also a subclass of `Compressor`. It compresses models by reducing the number of bits required to represent weights or activations, which can reduce computation and inference time. It contains:

### Quantization module wrapper

Each module/layer of the model to be quantized is wrapped by a quantization module wrapper, which provides a new `forward` method that quantizes the original module's weight, input and output.
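
A simplified sketch of the idea (not NNI's actual wrapper; names are illustrative):

```python
import torch

class QuantizerModuleWrapperSketch(torch.nn.Module):
    """Simplified sketch of a quantization module wrapper."""
    def __init__(self, module, quantizer, config):
        super().__init__()
        self.module = module
        self.quantizer = quantizer
        self.config = config  # e.g. {'quant_types': ['weight'], ...}

    def forward(self, x):
        if 'input' in self.config['quant_types']:
            x = self.quantizer.quantize_input(x, wrapper=self)
        if 'weight' in self.config['quant_types']:
            # quantize the weight so the original forward uses the quantized value
            self.quantizer.quantize_weight(self.module.weight, wrapper=self)
        output = self.module(x)
        if 'output' in self.config['quant_types']:
            output = self.quantizer.quantize_output(output, wrapper=self)
        return output
```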

### Quantization hook

A quantization hook is installed on the quantizer when it is constructed; it is called when `optimizer.step()` is invoked.

### Quantization methods

The `Quantizer` class provides the following methods for subclasses to implement quantization algorithms:

```python
class Quantizer(Compressor):
    """
    Base quantizer for pytorch quantizer
    """
    def quantize_weight(self, weight, wrapper, **kwargs):
        """
        quantize should overload this method to quantize weight.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        weight : Tensor
            weight that needs to be quantized
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_weight()')

    def quantize_output(self, output, wrapper, **kwargs):
        """
        quantize should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        output : Tensor
            output that needs to be quantized
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_output()')

    def quantize_input(self, *inputs, wrapper, **kwargs):
        """
        quantize should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        inputs : Tensor
            inputs that needs to be quantized
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_input()')
```
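
As an illustration, a toy quantizer that overrides only `quantize_weight` with symmetric 8-bit rounding might look like this (a sketch, not one of NNI's built-in algorithms; the `Quantizer` base-class import path may differ across NNI versions):

```python
import torch
# Assumed base-class location (NNI 2.x); adjust to your NNI version
from nni.compression.pytorch.compressor import Quantizer

class NaiveWeightQuantizerSketch(Quantizer):
    """Toy quantizer: symmetric 8-bit rounding of weights only."""
    def quantize_weight(self, weight, wrapper, **kwargs):
        scale = weight.abs().max() / 127 + 1e-12
        quantized = torch.clamp(torch.round(weight / scale), -128, 127) * scale
        # write the quantized values back so the wrapped forward uses them
        wrapper.module.weight.data = quantized
        return quantized
```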

***

## Multi-GPU support

During multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an in-place update is needed to ensure the update is effective.
Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and only on one GPU, it naturally supports multi-GPU training.
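
For the updates performed inside `forward`, the key distinction is between rebinding an attribute and mutating the registered buffer in place, as in this generic PyTorch sketch (`running_scale` is a hypothetical buffer):

```python
import torch

class ScaleTrackerSketch(torch.nn.Module):
    """Hypothetical wrapper that tracks an activation scale in its forward."""
    def __init__(self):
        super().__init__()
        self.register_buffer("running_scale", torch.zeros(1))

    def forward(self, x):
        new_scale = x.detach().abs().max().reshape(1)
        # In-place update: the registered buffer keeps the new value.
        # `self.running_scale = new_scale` would merely rebind the attribute.
        self.running_scale.copy_(new_scale)
        return x
```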