## Overview
The model compression framework has two main components: `pruner` and `module wrapper`.

### pruner
A `pruner` is responsible for:
1. providing a `calc_mask` method that calculates masks for weight and bias.
2. replacing modules with `module wrapper`s based on the config.
3. patching the optimizer so that `calc_mask` is called every time the optimizer's `step` method is called.
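
The third responsibility can be illustrated with a minimal sketch in plain Python. `ToyOptimizer` and `patch_step` are hypothetical names introduced here for illustration, not part of the framework; the sketch only assumes that patching means wrapping the existing `step` method so that an extra callback runs after each update.

```python
import types


class ToyOptimizer:
    """Stand-in for a real optimizer; it only counts how often step() runs."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


def patch_step(optimizer, after_step):
    """Replace optimizer.step so that after_step() runs after every update."""
    original_step = optimizer.step  # bound method of the unpatched optimizer

    def new_step(_self, *args, **kwargs):
        result = original_step(*args, **kwargs)
        after_step()  # in a pruner, this is where the masks would be recalculated
        return result

    optimizer.step = types.MethodType(new_step, optimizer)
    return optimizer


mask_updates = []
opt = patch_step(ToyOptimizer(), lambda: mask_updates.append("calc_mask"))
opt.step()
opt.step()
```

Because the patch is applied to the optimizer instance, the training loop can keep calling `opt.step()` unchanged.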

### module wrapper
A `module wrapper` is a module containing:
1. the original module
2. some buffers used by `calc_mask`
3. a new forward method that applies masks before running the original forward method.

The reasons to use a `module wrapper`:
1. `calc_mask` needs some buffers to calculate masks, and registering them in the `module wrapper` keeps the original modules uncontaminated.
2. a new `forward` method is needed to apply masks to the weight before calling the original `forward` method.
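
The wrapper idea can be sketched without PyTorch. `ToyLayer` and `ToyWrapper` below are hypothetical stand-ins, and overwriting the weight with its masked value inside `forward` is an assumption made for brevity; the point is only that the mask state lives on the wrapper and is applied before the original forward runs.

```python
class ToyLayer:
    """Stand-in for a layer: scales each input element by its weight."""

    def __init__(self, weight):
        self.weight = list(weight)

    def forward(self, x):
        return [w * v for w, v in zip(self.weight, x)]


class ToyWrapper:
    """Holds the original module plus the mask state used for pruning."""

    def __init__(self, module):
        self.module = module
        # Mask state lives on the wrapper, not on the original module.
        self.weight_mask = [1.0] * len(module.weight)

    def forward(self, x):
        # Apply the mask to the weight, then run the original forward.
        self.module.weight = [w * m for w, m in zip(self.module.weight, self.weight_mask)]
        return self.module.forward(x)


layer = ToyLayer([2.0, 3.0, 4.0])
wrapper = ToyWrapper(layer)
wrapper.weight_mask = [1.0, 0.0, 1.0]
output = wrapper.forward([1.0, 1.0, 1.0])  # the second weight is masked out
```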

## How it works
A basic pruner usage:
```python
configure_list = [{
    'sparsity': 0.7,
    'op_types': ['BatchNorm2d'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = SlimPruner(model, configure_list, optimizer)
model = pruner.compress()
```

A pruner receives the model, config, and optimizer as arguments. In the `__init__` method, the optimizer's `step` method is replaced with a new `step` method that also calls `calc_mask`. Each module is then checked against the config to decide whether it needs to be pruned; every module that does is replaced by a `module wrapper`. Afterward, the new model and new optimizer are returned and can be trained as before. The `compress` method calculates the default masks.
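
The module-replacement step might look roughly like the sketch below. It models the model as a plain dictionary of named modules and matches each module's class name against `op_types`; the helper names (`select_config`, `wrap_modules`, `ToyWrapper`) and the "first matching config wins" rule are assumptions for illustration, not the framework's actual behavior.

```python
class Conv2d:
    """Empty stand-in for a convolution layer."""


class BatchNorm2d:
    """Empty stand-in for a batch-norm layer."""


class ToyWrapper:
    def __init__(self, module, config):
        self.module = module
        self.config = config


def select_config(module, config_list):
    """Return the first config whose 'op_types' lists the module's class name."""
    for config in config_list:
        if type(module).__name__ in config["op_types"]:
            return config
    return None


def wrap_modules(named_modules, config_list):
    """Replace every module matched by a config with a wrapper around it."""
    wrapped = {}
    for name, module in named_modules.items():
        config = select_config(module, config_list)
        wrapped[name] = ToyWrapper(module, config) if config is not None else module
    return wrapped


configure_list = [{"sparsity": 0.7, "op_types": ["BatchNorm2d"]}]
model = {"conv1": Conv2d(), "bn1": BatchNorm2d()}
model = wrap_modules(model, configure_list)  # only "bn1" is wrapped
```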

## Implement a new pruning algorithm
Implementing a new pruning algorithm requires a new `pruner` class that subclasses `Pruner` and overrides the `calc_mask` method. `calc_mask` is called by the `optimizer.step` method.
The `Pruner` base class provides the basic functionality listed above, for example replacing modules and patching the optimizer.

A basic pruner looks like this:
```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        # do some initialization

    def calc_mask(self, wrapper, **kwargs):
        # do something to calculate weight_mask
        wrapper.weight_mask = weight_mask
```
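
To make the "calculate weight_mask" step concrete, here is one possible mask computation: a simple magnitude-based rule that zeroes the smallest weights. This is an illustrative choice, not the algorithm of any particular pruner in the framework, and `magnitude_mask` is a hypothetical helper that works on plain lists instead of tensors.

```python
def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity` fraction of smallest-magnitude weights."""
    num_pruned = int(len(weights) * sparsity)
    # Indices ordered from the smallest to the largest absolute weight.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_pruned])
    return [0.0 if i in pruned else 1.0 for i in range(len(weights))]


weights = [0.5, -0.1, 0.9, 0.05, -0.7]
mask = magnitude_mask(weights, sparsity=0.4)  # prunes the two smallest magnitudes
```

Inside a real `calc_mask`, the result of such a computation would be assigned to `wrapper.weight_mask`.
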
### Set wrapper attribute
Sometimes `calc_mask` must save some state data. Users can use the `set_wrappers_attribute` API to register attributes, just as buffers are registered in PyTorch modules. These buffers are registered on the `module wrapper`, and users can access them through the `module wrapper`.

```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
    
    def calc_mask(self, wrapper):
        # Calculate the masks only once per wrapper; later calls do nothing.
        if not wrapper.if_calculated:
            wrapper.if_calculated = True
            # update masks
```
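
One way such an API could work is sketched below. `ToyPruner` and `ToyWrapper` are hypothetical, and giving each wrapper its own deep copy of the value is an assumption chosen so that mutable defaults such as `[]` are not shared between wrappers; the framework's actual implementation may differ.

```python
import copy


class ToyWrapper:
    """Empty stand-in for a module wrapper."""


class ToyPruner:
    def __init__(self, wrappers):
        self.wrappers = wrappers

    def set_wrappers_attribute(self, name, value):
        # Give every wrapper its own copy so mutable values are not shared.
        for wrapper in self.wrappers:
            setattr(wrapper, name, copy.deepcopy(value))


wrappers = [ToyWrapper(), ToyWrapper()]
pruner = ToyPruner(wrappers)
pruner.set_wrappers_attribute("if_calculated", False)
pruner.set_wrappers_attribute("collected_activation", [])
wrappers[0].collected_activation.append(1.5)  # does not affect wrappers[1]
```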

### Collect data during forward
Sometimes users want to collect some data during a module's forward method, for example the mean value of the activation. To do this, users can add a customized collector to the module.

```python
class ActivationRankFilterPruner(Pruner):
    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        self.set_wrappers_attribute("collected_activation", [])
        self.statistics_batch_num = statistics_batch_num

        # Validate and resolve the activation before the collector is registered.
        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        else:
            self.activation = torch.nn.functional.relu6

        def collector(module_, input_, output):
            # Keep activations from the first `statistics_batch_num` batches only.
            if len(module_.collected_activation) < self.statistics_batch_num:
                module_.collected_activation.append(self.activation(output.detach().cpu()))
        self.add_activation_collector(collector)
```
The collector function will be called each time the forward method runs.

Users can also remove this collector like this:
```python
collector_id = self.add_activation_collector(collector)
# ...
self.remove_activation_collector(collector_id)
```
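
The add/remove lifecycle can be sketched with a toy hook registry. `ToyModule`, `add_collector`, and `remove_collector` are hypothetical names used only for this illustration; the sketch assumes that a collector is a callback invoked with the module, its input, and its output after each forward pass, and that it is identified by the id returned at registration.

```python
import itertools


class ToyModule:
    """Stand-in for a module that supports forward-time collectors."""

    def __init__(self):
        self._hooks = {}
        self._next_id = itertools.count()

    def add_collector(self, collector):
        collector_id = next(self._next_id)
        self._hooks[collector_id] = collector
        return collector_id

    def remove_collector(self, collector_id):
        self._hooks.pop(collector_id, None)

    def forward(self, x):
        output = [max(0.0, v) for v in x]  # ReLU-like toy computation
        for hook in list(self._hooks.values()):
            hook(self, x, output)
        return output


collected = []
module = ToyModule()
collector_id = module.add_collector(
    lambda m, inp, out: collected.append(sum(out) / len(out))  # mean activation
)
module.forward([-1.0, 3.0])            # collector runs
module.remove_collector(collector_id)
module.forward([5.0, 5.0])             # collector no longer runs
```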

### Multi-GPU support
In multi-GPU training, buffers and parameters are copied to every GPU each time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an in-place update is needed for the update to take effect.
Since `calc_mask` is called in the `optimizer.step` method, which runs after the `forward` method and only on one GPU, it supports multi-GPU training naturally.
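
The difference between an in-place update and rebinding an attribute can be shown with a loose analogy in plain Python. The sketch assumes a replica behaves like a shallow copy that shares its buffer storage with the base module; this is a simplification for illustration, not a description of how device replication is implemented.

```python
import copy


class ToyModule:
    def __init__(self):
        self.buffer = [0.0, 0.0]


base = ToyModule()

# Rebinding the attribute on a replica leaves the base module unchanged.
replica = copy.copy(base)  # shallow copy: the replica shares the buffer object
replica.buffer = [1.0, 1.0]
after_rebind = list(base.buffer)

# An in-place update writes into the shared buffer, so the base module sees it.
replica = copy.copy(base)
replica.buffer[:] = [1.0, 1.0]
after_in_place = list(base.buffer)
```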