# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in PyTorch.
Some of the code here will be included in upstream PyTorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)

## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides

# Contents

## 1. Amp:  Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.

[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).

[API Documentation](https://nvidia.github.io/apex/amp.html)

[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`.  It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.

[API Documentation](https://nvidia.github.io/apex/parallel.html)

[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)

The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
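
As a rough mental model of what such a wrapper does each step: every process computes gradients on its own slice of the global batch, and an allreduce averages them so all replicas apply the same update. The pure-Python sketch below checks that, with equal local batch sizes and a mean-reduced loss, the averaged local gradients equal the gradient over the whole global batch. The one-parameter model `y = w * x`, the helpers `local_grad` and `allreduce_mean`, and the data are hypothetical; lists stand in for processes, and nothing here reflects how Apex buckets or overlaps communication.

```python
def local_grad(w, xs, ys):
    """Gradient d/dw of mean((w * x - y) ** 2) over one process's local minibatch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values):
    """Stand-in for an allreduce (sum) followed by division by the world size."""
    return sum(values) / len(values)


# Hypothetical global batch of 4 samples split evenly across 2 "processes".
w = 0.5
shards = [([1.0, 2.0], [1.0, 3.0]), ([3.0, 4.0], [2.0, 5.0])]

ddp_grad = allreduce_mean([local_grad(w, xs, ys) for xs, ys in shards])
full_grad = local_grad(w, [1.0, 2.0, 3.0, 4.0], [1.0, 3.0, 2.0, 5.0])
print(ddp_grad, full_grad)  # -9.0 -9.0
```

The equality relies on every process holding the same number of samples; with uneven shards, a plain average of local means would weight the samples unequally.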

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronized BN is useful when only a small local minibatch fits on each GPU:
allreducing the stats increases the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronized BN has been observed to improve converged accuracy in some of our research models.
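
To make "allreduces stats" concrete: if each process contributes its element count, sum, and sum of squares for a channel, summing those across processes is enough to recover the mean and (biased) variance of the global batch. The sketch below is a hypothetical pure-Python illustration of that idea for a single channel, with lists standing in for GPUs and a plain sum standing in for the allreduce. Apex's fused kernels may combine statistics differently (for example with a more numerically stable Welford-style merge), so treat this as the math, not the implementation.

```python
def local_stats(values):
    """Per-process statistics for one channel: (count, sum, sum of squares)."""
    return len(values), sum(values), sum(v * v for v in values)


def allreduce_sum(per_process):
    """Stand-in for an allreduce (sum) of the statistics tuples."""
    return tuple(sum(component) for component in zip(*per_process))


def global_mean_var(per_process):
    """Global mean and biased variance, as if BN saw the whole global batch."""
    count, total, total_sq = allreduce_sum(per_process)
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


# Hypothetical activations of one channel on three "GPUs" with uneven batch sizes.
channel_shards = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
print(global_mean_var([local_stats(s) for s in channel_shards]))
```

Because counts are summed explicitly, uneven local batch sizes are handled correctly; the sum-of-squares formula, however, can lose precision when the mean is large relative to the spread.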

### Checkpointing

To properly save and restore `amp` training, use `amp.state_dict()`, which contains all loss scalers and their corresponding unskipped steps,
together with `amp.load_state_dict()`, which restores these attributes.

To resume training with bitwise accuracy, we recommend the following workflow:
```python
import torch
from apex import amp

model = ...      # your torch.nn.Module, already moved to the GPU
optimizer = ...  # e.g. torch.optim.SGD(model.parameters(), lr=0.1)

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
optimizer.zero_grad()
loss = ...  # loss computed from the model's output
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore: recreate the model and optimizer exactly as before
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...
```

Note that we recommend restoring the model using the same `opt_level` it was trained with, and calling the `load_state_dict` methods after `amp.initialize`.
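
The `unskipped` count saved with each loss scaler matters because dynamic loss scaling is stateful: the scale shrinks when the scaled gradients overflow and grows again after a run of overflow-free steps, so a resumed run needs to know how far into that run it already was. The class below is a hypothetical, simplified dynamic loss scaler in pure Python, with gradients as plain floats and defaults (initial scale `2**16`, factor 2, window 2000) that are assumptions; it is not Apex's loss scaler, but its `state_dict()` holds the same two quantities described above.

```python
import math


class ToyDynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (not Apex's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, scale_factor=2.0, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.unskipped = 0  # consecutive steps without gradient overflow

    def update(self, scaled_grads):
        """Adjust the scale after backward; return True if the step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.loss_scale /= self.scale_factor  # overflow: shrink and skip the step
            self.unskipped = 0
            return True
        self.unskipped += 1
        if self.unskipped == self.scale_window:
            self.loss_scale *= self.scale_factor  # long clean run: try a larger scale
            self.unskipped = 0
        return False

    def state_dict(self):
        return {"loss_scale": self.loss_scale, "unskipped": self.unskipped}

    def load_state_dict(self, state):
        self.loss_scale = state["loss_scale"]
        self.unskipped = state["unskipped"]


scaler = ToyDynamicLossScaler(scale_window=3)
scaler.update([float("inf")])  # overflow: scale halves to 2**15
scaler.update([0.5])           # one clean step recorded in `unskipped`
resumed = ToyDynamicLossScaler(scale_window=3)
resumed.load_state_dict(scaler.state_dict())
print(resumed.state_dict())    # {'loss_scale': 32768.0, 'unskipped': 1}
```

If only `loss_scale` were restored and `unskipped` restarted at 0, the resumed run would enlarge its scale at a different step than an uninterrupted run would, which is enough to break bitwise agreement.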

# Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer.  The CUDA and C++ extensions require PyTorch 1.0 or newer.

We recommend the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/).  We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).

It's often convenient to use Apex in Docker containers.  Compatible options include:
* [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled.  To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
* [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.

See the [Docker example folder](https://github.com/NVIDIA/apex/tree/master/examples/docker) for details.

## On ROCm:
* Python 3.6
* PyTorch 1.5 or newer. The HIP extensions require PyTorch 1.5 or newer.
* We recommend following the instructions from [ROCm-Pytorch](https://github.com/ROCmSoftwarePlatform/pytorch) to install PyTorch on ROCm.
* Note: For PyTorch versions < 1.8, building from source is no longer supported; please use the release package [ROCm-Apex v0.3](https://github.com/ROCmSoftwarePlatform/apex/releases/tag/v0.3).
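
The version rules above can be summarized in a small check, for example in a setup script. `rocm_install_route` is a hypothetical helper, not part of Apex; it only inspects the leading `major.minor` of a version string such as `torch.__version__` and ignores suffixes like `+rocm4.2`.

```python
import re


def rocm_install_route(torch_version):
    """Hypothetical helper: map a PyTorch version string to a ROCm install route."""
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {torch_version!r}")
    major_minor = (int(match.group(1)), int(match.group(2)))
    if major_minor < (1, 5):
        return "unsupported (PyTorch 1.5 or newer is required)"
    if major_minor < (1, 8):
        return "use the ROCm-Apex v0.3 release package"
    return "build from source"


for version in ["1.4.0", "1.7.1", "1.9.0+rocm4.2"]:
    print(version, "->", rocm_install_route(version))
```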

# Quick Start

### ROCm
Apex on ROCm supports both a Python-only build and an extension build.
Note: PyTorch 1.5 or newer is recommended for the extension build.

### To install using the Python-only build, run the following command in the apex folder:
```
python3.6 setup.py install
```

### To install with extensions enabled, run the following command in the apex folder:
```
python3.6 setup.py install --cpp_ext --cuda_ext
```

### To install Apex on ROCm using ninja and without cloning the source
```
pip3.6 install ninja
pip3.6 install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
```

### Linux

For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```

Apex also supports a Python-only build (required with PyTorch 0.4) via
```
$ pip install -v --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.

`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
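
To see which kind of build you ended up with, one option is to look for the compiled extension modules on the import path without importing them. The sketch assumes the module names `apex_C` (from `--cpp_ext`) and `amp_C`, `fused_layer_norm_cuda`, and `syncbn` (from `--cuda_ext`), which is what past Apex versions built; the exact list may differ in your version or fork.

```python
import importlib.util

# Assumption: compiled module names produced by the --cpp_ext / --cuda_ext builds
# in past Apex versions; the list may differ in your version or fork.
EXTENSION_MODULES = ("apex_C", "amp_C", "fused_layer_norm_cuda", "syncbn")


def missing_extensions(names=EXTENSION_MODULES):
    """Return the names in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_extensions()
if missing:
    print("Python-only (or partial) build; missing:", ", ".join(missing))
else:
    print("CUDA and C++ extensions found")
```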

PyProf support has been moved to its own [dedicated repository](https://github.com/NVIDIA/PyProf).
The PyProf codebase in Apex is deprecated and will be removed soon.

### Windows support
Windows support is experimental, and Linux is recommended.  `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build PyTorch from source
on your system.  `pip install -v --no-cache-dir .` (without CUDA/C++ extensions) is more likely to work.  If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.