# Introduction

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)

## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides

# Contents

## 1. Amp:  Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
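
For illustration, here is a minimal sketch of those three lines, assuming an ordinary Pytorch model, optimizer, and data loader (the names below are placeholders, not part of Apex):

```python
import torch
from apex import amp

# Placeholder model/optimizer; any standard Pytorch module and optimizer work.
model = torch.nn.Linear(512, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Line 1: let Amp patch the model/optimizer for the chosen precision mode.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for inputs, targets in loader:  # `loader` is an assumed DataLoader
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    # Lines 2-3: scale the loss so fp16 gradients do not underflow.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```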

[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).

[API Documentation](https://nvidia.github.io/apex/amp.html)

[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`.  It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
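
For reference, a minimal sketch of wrapping a model follows; the process-group setup and the model itself are placeholder assumptions, and each process is expected to be launched by a distributed launcher such as `torchrun`:

```python
import torch
from apex.parallel import DistributedDataParallel as DDP

# Assumes one process per GPU, launched by a distributed launcher that sets
# the usual environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...).
torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(512, 10).cuda()  # placeholder model

# Drop-in analogue of torch.nn.parallel.DistributedDataParallel: gradients
# are allreduced across processes during backward() via NCCL.
model = DDP(model)
```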

[API Documentation](https://nvidia.github.io/apex/parallel.html)

[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)

The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
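
As a sketch, one common way to enable it is the `apex.parallel.convert_syncbn_model` helper, which swaps existing `torch.nn` batchnorm layers for `apex.parallel.SyncBatchNorm`; the network below is a placeholder, and the process group is assumed to be initialized as in the distributed example above:

```python
import torch
import apex

# Placeholder network containing ordinary BatchNorm layers.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
).cuda()

# Replace every torch.nn.modules.batchnorm._BatchNorm instance with
# apex.parallel.SyncBatchNorm so stats are allreduced across processes.
model = apex.parallel.convert_syncbn_model(model)

# Wrap for multiprocess training as usual.
model = apex.parallel.DistributedDataParallel(model)
```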

### Checkpointing

To properly save and load your `amp` training, we introduce `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
as well as `amp.load_state_dict()` to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...
```

Note that we recommend restoring the model using the same `opt_level`, and calling the `load_state_dict` methods after `amp.initialize`.

# Installation

## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions currently available.

See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes

## From Source

To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.

The latest stable release obtainable from https://pytorch.org should also work.

### ROCm
Apex on ROCm supports both a Python-only build and an extension build.
Note: Pytorch >= 1.5 is recommended for the extension build.

### To install using a Python-only build, run the following command in the apex folder:
```
python setup.py install
```

### To install with extensions enabled, run the following command in the apex folder:
```
python setup.py install --cpp_ext --cuda_ext
```
Note that installing Apex with the `--cuda_ext` flag will also enable all the extensions supported on ROCm, including `--distributed_adam`, `--distributed_lamb`, `--bnp`, `--xentropy`, `--deprecated_fused_adam`, `--deprecated_fused_lamb`, and `--fast_multihead_attn`.

### To install Apex on ROCm using ninja and without cloning the source
```
pip install ninja
pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
```

### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.

### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.  
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.