# Design Doc

## Overview

The following example shows how to use a pruner:

```python
import torch
from nni.compression.torch import LevelPruner

# load a pretrained model or train a model before using a pruner
# model = ...

# prune 70% of the weights in every Conv2d and Linear layer
configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# the model is now ready for pruning; start fine-tuning it,
# and it will be pruned automatically during training
```

A pruner receives `model`, `config_list` and `optimizer` as arguments. It prunes the model according to the `config_list` during the training loop by adding a hook on `optimizer.step()`.
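
Conceptually, the hook works like the following sketch (illustrative only, not NNI's actual implementation; `get_modules_wrapper` and `weight_mask` are assumed names here): the pruner patches `optimizer.step()` so that masks are recalculated after every weight update.

```python
def patch_optimizer(pruner, optimizer):
    original_step = optimizer.step

    def patched_step(*args, **kwargs):
        result = original_step(*args, **kwargs)
        # recalculate masks after the weights have been updated
        for wrapper in pruner.get_modules_wrapper():  # assumed helper
            masks = pruner.calc_mask(wrapper)
            if masks is not None:
                wrapper.weight_mask.copy_(masks['weight_mask'])
        return result

    optimizer.step = patched_step
```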

From an implementation perspective, a pruner consists of a `weight masker` instance and multiple `module wrapper` instances.

### Weight masker

A `weight masker` is the implementation of a pruning algorithm: it prunes a specified layer, wrapped by a `module wrapper`, to a specified sparsity.

### Module wrapper

A `module wrapper` is a module containing:

1. the original module,
2. some buffers used by `calc_mask`,
3. a new `forward` method that applies masks before running the original forward method.

The reasons to use a `module wrapper` (see the sketch after this list):

1. some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in the `module wrapper` so that the original modules are not contaminated.
2. a new `forward` method is needed to apply masks to the weights before calling the real `forward` method.
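
A minimal sketch of such a wrapper (simplified and hypothetical, for illustration only; NNI's actual wrapper is more general):

```python
import torch

class MyModuleWrapper(torch.nn.Module):
    """A simplified, hypothetical module wrapper, for illustration only."""
    def __init__(self, module, config, pruner):
        super().__init__()
        self.module = module    # the original module
        self.config = config
        self.pruner = pruner
        # buffer used by `calc_mask`, registered on the wrapper so the
        # original module is not contaminated
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask to the weight before running the real forward
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(*inputs)
```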

### Pruner

A `pruner` is responsible for:

1. Manage and verify `config_list`.
2. Use `module wrapper` to wrap the model layers and add a hook on `optimizer.step`.
3. Use `weight masker` to calculate the masks of layers while pruning.
4. Export pruned model weights and masks.

## Implement a new pruning algorithm

Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.

An implementation of `weight masker` may look like this:

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting statistics,
        # if your algorithm needs them to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the mask based on wrapper.weight, the sparsity,
        # and anything else you need
        # mask = ...
        return {'weight_mask': mask}
```
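
For example, a simple magnitude-based masker (a sketch in the spirit of `LevelPruner`, not NNI's exact code; it assumes the original module is reachable as `wrapper.module`) could be written as:

```python
import torch

class MagnitudeMasker(WeightMasker):
    # hypothetical masker: prune the weights with the smallest absolute values
    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        weight = wrapper.module.weight.data  # assumes wrapper.module is the original module
        num_prune = int(weight.numel() * sparsity)
        if num_prune == 0:
            return {'weight_mask': torch.ones_like(weight)}
        # threshold = largest magnitude among the num_prune smallest weights
        threshold = torch.topk(weight.abs().view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(weight.abs(), threshold).type_as(weight)
        return {'weight_mask': mask}
```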

You can refer to the NNI-provided [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations when implementing your own weight masker.

A basic pruner looks like this:

```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        # construct a weight masker instance
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # already pruned; as a one-shot pruner, do not prune again
            return None
        else:
            # call your masker to actually calculate the mask for this layer
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            return masks
```

Refer to the NNI-provided [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations to implement your own pruner class.

### Set wrapper attribute

Sometimes `calc_mask` needs to save some state data, so users can use the `set_wrappers_attribute` API to register attributes, just like buffers are registered in PyTorch modules. These buffers are registered to the `module wrapper`, and users can access them through the `module wrapper`.
In the above example, we use `set_wrappers_attribute` to set a buffer `if_calculated`, which is used as a flag indicating whether the mask of a layer has already been calculated.
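
For example, the attribute can be read back from each wrapper (`get_modules_wrapper` and `wrapper.name` are assumed names in this sketch):

```python
# hypothetical usage: inspect the registered attribute on each wrapper
for wrapper in pruner.get_modules_wrapper():  # assumed helper
    print(wrapper.name, wrapper.if_calculated)
```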

### Collect data during forward

Sometimes users want to collect some data during the modules' forward method, for example, the mean value of the activation. This can be done by adding a customized collector to the module.

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # Set attribute `collected_activation` for all wrappers to store
        # activations for each layer
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.activation = torch.nn.functional.relu

        def collector(wrapper, input_, output):
            # The collected activation can be accessed via each wrapper's collected_activation
            # attribute
            wrapper.collected_activation.append(self.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```

The collector function will be called each time the forward method runs.
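
A masker can then consume the collected activations in `calc_mask`; for example (a hypothetical sketch assuming 4-D conv outputs and that the original module is reachable as `wrapper.module`), pruning the output channels with the smallest mean activation:

```python
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
    # stack all collected batches: shape (N, C, H, W) for a conv layer
    activations = torch.cat(wrapper.collected_activation, dim=0)
    # mean activation per output channel
    channel_means = activations.mean(dim=(0, 2, 3))
    num_prune = int(channel_means.numel() * sparsity)
    # zero out whole channels with the smallest mean activation
    prune_idx = torch.argsort(channel_means)[:num_prune]
    mask = torch.ones_like(wrapper.module.weight.data)
    mask[prune_idx] = 0.
    return {'weight_mask': mask}
```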

Users can also remove this collector like this:

```python
# Save the collector identifier
collector_id = self.pruner.add_activation_collector(collector)

# When the collector is no longer needed, it can be removed using
# the saved collector identifier
self.pruner.remove_activation_collector(collector_id)
```

### Multi-GPU support

During multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in `forward`, an in-place update is needed to make the update effective.
Since `calc_mask` is called from the `optimizer.step` method, which runs after `forward` and on only one GPU, it supports multi-GPU training naturally.