amp_zh.md 2.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# 混合精度训练

ColossalAI可以使用如下三种不同的混合精度训练方式:
1. torch.cuda.amp
2. apex.amp
3. 张量并行AMP

前两种混合精度训练方式依赖于[PyTorch](https://pytorch.org/docs/stable/amp.html)的原生实现(1.6或以上版本)以及
[Nvidia Apex](https://github.com/NVIDIA/apex),但这两种方法与张量并行并不兼容,因为在张量并行中我们需要将张量进行切分并保存在不同的设备上,
因此,实现兼容张量并行的混合精度训练需要在不同进程之间不断通信来交流`inf`以及`nan`是否存在于模型参数中,因此我们才用了
[Megatron-LM](https://github.com/NVIDIA/Megatron-LM)的实现方式。

您可以简单地将配置文件中的`fp16`字段设置为True来使用混合精度训练。目前,PyTorch与Apex的amp不能保证与张量和流水线并行兼容,因此,我们推荐您使用
最后一种混合精度训练方式。

## PyTorch AMP

PyTorch在1.6及以上版本中提供了混合精度训练,其可以在保持一些操作的精度为`fp32`的同时,将数据转换成`fp16`格式,您可以在配置文件中配置使用。

```python
from colossalai.engine import AMP_TYPE

fp16=dict(
    mode=AMP_TYPE.TORCH,
    # below are default values for grad scaler
    init_scale=2.**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
    enabled=True
)
```

## Apex AMP

我们使用了[Apex](https://nvidia.github.io/apex/)中的混合精度训练,因为该模式提供了细粒度的混合精度控制,例如,`O2`级(第二级优化器)将会保持
批标准化在`fp32`上进行。下面的代码块展示了使用Apex AMP的配置文件。

```python
from colossalai.engine import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.APEX,
    # below are the default values
    enabled=True, 
    opt_level='O1', 
    cast_model_type=None, 
    patch_torch_functions=None, 
    keep_batchnorm_fp32=None, 
    master_weights=None, 
    loss_scale=None, 
    cast_model_outputs=None,
    num_losses=1, 
    verbosity=1, 
    min_loss_scale=None, 
    max_loss_scale=16777216.0
)
```

## 张量并行AMP

我们借鉴了Megatron-LM的混合精度训练实现,该实现方式与张量并行与流水线并行相兼容。下面的代码块展示了使用张量并行AMP的配置文件。

```python
from colossalai.engine import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.PARALLEL,
    # below are the default values
    clip_grad=0,
    log_num_zeros_in_grad=False,
    initial_scale=2 ** 32,
    min_scale=1,
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=1000,
    hysteresis=2
)
```