ModelSpeedup.md 4.34 KB
Newer Older
Chi Song's avatar
Chi Song committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# 加速掩码的模型

*此功能还处于预览版。*

## 介绍

剪枝算法通常都用权重掩码来模拟实际的剪枝。 掩码可以用来检查某个剪枝(或稀疏)算法的模型性能,但还没有真正加速。 模型加速才是模型剪枝的最终目标。因此提供了此工具,来帮助基于用户提供的掩码(掩码来自于剪枝算法),将已有模型转换成小模型。

有两种剪枝算法。 一种是细粒度的剪枝,不改变权重形状,和输入输出的张量。 稀疏内核会被用来加速细粒度剪枝的层。 另一类是粗粒度的剪枝(例如,通道),通常,权重形状,输入输出张量会有所改变。 要加速这类剪枝算法,不需要使用系数内核,只需要用更小的层来替换。 由于开源社区中对稀疏内核的支持还比较有限,当前仅支持粗粒度剪枝,会在将来再支持细粒度的剪枝算法。

## 设计和实现

为了加速模型,被剪枝的层应该被替换掉,要么为粗粒度掩码使用较小的层,要么用稀疏内核来替换细粒度的掩码。 粗粒度掩码通常会改变权重的形状,或输入输出张量,因此,应该通过形状推断,来检查是否其它未被剪枝的层由于形状变化而需要改变形状。 因此,在设计中,主要有两个步骤:第一,做形状推理,找出所有应该替换的模块;第二,替换模块。 第一步需要模型的拓扑(即连接),我们使用了 `jit.trace` 来获取 PyTorch 的模型图。

对于每个模块,要准备四个函数,三个用于形状推理,一个用于模块替换。 三个形状推理函数是:给定权重形状推断输入/输出形状,给定输入形状推断权重/输出形状,给定输出形状推断权重/输入形状。 模块替换功能返回一个较小的新创建的模块。

## 用法

```python
from nni.compression.speedup.torch import ModelSpeedup
# model: 要加速的模型
# dummy_input: 模型的示输入,传给 `jit.trace`
# masks_file: 剪枝算法创建的掩码文件
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
```
完整示例参考[这里](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py)

注意:当前实现仅用于 torch 1.3.1 和 torchvision 0.4.2

## 局限性

由于每个模块需要 4 个函数用于形状推理和模块替换,因此工作量较大,当前仅实现了示例所需的函数。 如果要加速自己的模型,但当前不支持,欢迎贡献。

对于 PyTorch,仅提供了替换模块,如果是在 `forward` 中的函数,当前不支持。 一种解决方案是将函数变为 PyTorch 模块。

## 示例的加速结果

实验代码可在[这里](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py)找到。

### slim Pruner 示例

在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`

| 次数 | 掩码时延    | 加速后的时延   |
| -- | ------- | -------- |
| 1  | 0.01197 | 0.005107 |
| 2  | 0.02019 | 0.008769 |
| 4  | 0.02733 | 0.014809 |
| 8  | 0.04310 | 0.027441 |
| 16 | 0.07731 | 0.05008  |
| 32 | 0.14464 | 0.10027  |

### fpgm Pruner 示例

在 CPU 上, 输入张量:`torch.randn(64, 1, 28, 28)`, 方差较大

| 次数  | 掩码时延    | 加速后的时延   |
| --- | ------- | -------- |
| 1   | 0.01383 | 0.01839  |
| 2   | 0.01167 | 0.003558 |
| 4   | 0.01636 | 0.01088  |
| 40  | 0.14412 | 0.08268  |
| 40  | 1.29385 | 0.14408  |
| 40  | 0.41035 | 0.46162  |
| 400 | 6.29020 | 5.82143  |

### l1filter Pruner 示例

在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`

| 次数 | 掩码时延    | 加速后的时延   |
| -- | ------- | -------- |
| 1  | 0.01026 | 0.003677 |
| 2  | 0.01657 | 0.008161 |
| 4  | 0.02458 | 0.020018 |
| 8  | 0.03498 | 0.025504 |
| 16 | 0.06757 | 0.047523 |
| 32 | 0.10487 | 0.086442 |

### APoZ Pruner 示例

在一块 V100 GPU 上, 输入张量:`torch.randn(64, 3, 32, 32)`

| 次数 | 掩码时延    | 加速后的时延   |
| -- | ------- | -------- |
| 1  | 0.01389 | 0.004208 |
| 2  | 0.01628 | 0.008310 |
| 4  | 0.02521 | 0.014008 |
| 8  | 0.03386 | 0.023923 |
| 16 | 0.06042 | 0.046183 |
| 32 | 0.12421 | 0.087113 |