Framework.md 4.49 KB
Newer Older
Chi Song's avatar
Chi Song committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# 设计文档

## 概述
模型压缩框架有两个主要组件: `Pruner``module 的包装`

### Pruner
`Pruner` 用于:
1. 提供 `cal_mask` 方法来计算权重和偏差的掩码(mask)。
2. 根据配置,用 `module 的包装`来替换原始的 module。
3. 修改优化器,来在 `step` 方法被调用时,调用 `cal_mask`

### module 的包装
`module 的包装` 包含:
1. 原始的 module
2. `cal_mask` 使用的一些缓存
3. 新的 forward 方法,用于在运行原始的 forward 方法前应用掩码。

使用 `module 包装`的原因:
1. 计算掩码所需要的 `cal_mask` 方法需要一些缓存,这些缓存需要注册在 `module 包装`里,这样就不需要修改原始的 module。
2. 新的 `forward` 方法用来在原始 `forward` 调用前,将掩码应用到权重上。

## 工作原理
基本的 Pruner 用法:
```python
configure_list = [{
    'sparsity': 0.7,
    'op_types': ['BatchNorm2d'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = SlimPruner(model, configure_list, optimizer)
model = pruner.compress()
```

Pruner 接收模型,配置和优化器作为参数。 在 `__init__` 方法中,优化器的 `step` 方法会被一个会调用 `cal_mask` 的新的 `step` 方法替换。 同样,所有 module 都会检查它们是否被配置为需要剪枝。如果 module 需要被剪枝,就会用 `module 包装`来替换它。 之后,会返回新的模型和优化器,并进行训练。 `compress` 方法会计算默认的掩码。

## 实现新的剪枝算法
要实现新的剪枝算法,需要继承 `Pruner` 来实现新的类,并重载 `cal_mask` 方法。 `cal_mask` 会被 `optimizer.step` 方法调用。 `Pruner` 基类提供了上述的基本功能,如替换 module 和优化器。

基础的 Pruner 如下所示:
```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer)
        super().__init__(model, config_list, optimizer)
        # 进行初始化

    def calc_mask(self, wrapper, **kwargs):
        # 计算 weight_mask
        wrapper.weight_mask = weight_mask
```
### 设置包装的属性
有时,`cal_mask` 需要保存一些状态数据,可以像 PyTorch 的 module 一样,使用 `set_wrappers_attribute` API 来注册属性。 这些缓存会注册到 `module 包装`中。 用户可以通过 `module 包装`来直接访问这些缓存。

```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper):
        # 计算 weight_mask
        if wrapper.if_calculated:
            pass
        else:
            wrapper.if_calculated = True
            # 更新掩码
```

### 在 forward 时收集数据
有时,需要在 forward 方法中收集数据,例如,需要激活的平均值。 这时,可以为 module 增加定制的收集方法。

```python
class ActivationRankFilterPruner(Pruner):
    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        self.set_wrappers_attribute("collected_activation", [])
        self.statistics_batch_num = statistics_batch_num

        def collector(module_, input_, output):
            if len(module_.collected_activation) < self.statistics_batch_num:
                module_.collected_activation.append(self.activation(output.detach().cpu()))
        self.add_activation_collector(collector)
        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'relu6':
            self.activation = torch.nn.functional.relu6
        else:
            self.activation = None
```
收集函数会在每次 forward 方法运行时调用。

还可这样来移除收集方法:
```python
collector_id = self.add_activation_collector(collector)
# ...
self.remove_activation_collector(collector_id)
```

### 多 GPU 支持
在多 GPU 训练中,缓存和参数会在每次 `forward` 方法被调用时,复制到多个 GPU 上。 如果缓存和参数要在 `forward` 更新,就需要通过`原地`更新来提高效率。 因为 `cal_mask` 会在 `optimizer.step` 方法中的调用,会在 `forward` 方法后才被调用,且只会发生在单 GPU 上,因此它天然的就支持多 GPU 的情况。