# Mixed Precision ImageNet Training in PyTorch

`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/imagenet](https://github.com/pytorch/examples/tree/master/imagenet).
It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset.  Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s.  For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).

Three lines enable Amp:
```
# Added after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, flags...)
...
# loss.backward() changed to:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

With the new Amp API **you never need to explicitly convert your model, or the input data, to half().**
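For example, a minimal sketch of how those three lines sit inside an otherwise ordinary FP32 training loop (the `train_loader` and hyperparameter values here are placeholders, not the exact code in `main_amp.py`):
```
import torch
import torchvision.models as models
from apex import amp

model = models.resnet50().cuda()                  # model is constructed in FP32 as usual
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss().cuda()

# Amp inserts the casts required by the chosen opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for images, target in train_loader:               # assumed ImageNet DataLoader; inputs stay FP32
    images, target = images.cuda(), target.cuda()
    loss = criterion(model(images), target)
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```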

## Requirements

- Download the ImageNet dataset and move validation images to labeled subfolders
    - The following script may be helpful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh

## Training

To train a model, create softlinks to the ImageNet dataset, then run `main_amp.py` with the desired model architecture, as shown in `Example commands` below.

The default learning rate schedule is set for ResNet50.  The `main_amp.py` script rescales the learning rate according to the global batch size (number of distributed processes \* per-process minibatch size).
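As a sketch of that rescaling (the reference batch size of 256 and the exact linear-scaling formula are assumptions about the script, shown only to illustrate the idea):
```
# hypothetical values for illustration
base_lr = 0.1        # learning rate tuned for a global batch size of 256
batch_size = 224     # per-process minibatch size (--b)
world_size = 2       # number of distributed processes

# linear scaling rule: the learning rate grows with the global batch size
lr = base_lr * (batch_size * world_size) / 256.0
print(lr)            # 0.175
```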

## Example commands

**Note:**  Batch size `--b 224` assumes your GPUs have >=16GB of onboard memory.  You may be able to increase this to 256, but that's cutting it close, so it may run out of memory with some Pytorch versions.

**Note:**  All of the following use 4 dataloader subprocesses (`--workers 4`) to reduce potential
CPU data loading bottlenecks.

**Note:**  `--opt-level` `O1` and `O2` both use dynamic loss scaling by default unless manually overridden.
`--opt-level` `O0` and `O3` (the "pure" training modes) do not use loss scaling by default.
`O0` and `O3` can be told to use loss scaling via manual overrides, but using loss scaling with `O0`
(pure FP32 training) does not really make sense, and will trigger a warning.
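In terms of the Amp API, the `--loss-scale` flag roughly corresponds to the `loss_scale` argument of `amp.initialize`; a sketch of the alternative calls (only one would appear in a real script, and `model`/`optimizer` are assumed to be constructed already):
```
from apex import amp

# O3 default: no loss scaling
model, optimizer = amp.initialize(model, optimizer, opt_level="O3")

# manual override: static loss scaling
model, optimizer = amp.initialize(model, optimizer, opt_level="O3", loss_scale=128.0)

# manual override: dynamic loss scaling
model, optimizer = amp.initialize(model, optimizer, opt_level="O3", loss_scale="dynamic")
```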

Softlink training and validation datasets into the current directory:
```
$ ln -sf /data/imagenet/train-jpeg/ train
$ ln -sf /data/imagenet/val-jpeg/ val
```

### Summary

Amp allows easy experimentation with various pure and mixed precision options.
```
$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```
Options are explained below.  Again, the [updated API guide](https://nvidia.github.io/apex/amp.html) provides more detail.

#### `--opt-level O0` (FP32 training) and `O3` (FP16 training)

"Pure FP32" training:
```
$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./
```
"Pure FP16" training:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./
```
FP16 training with FP32 batchnorm:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./
```
Keeping the batchnorms in FP32 improves stability and allows Pytorch
to use cudnn batchnorms, which significantly increases speed for ResNet50.

The `O3` options might not converge, because they are not true mixed precision.
However, they can be useful to establish "speed of light" performance for
your model, which provides a baseline for comparison with `O1` and `O2`.
For ResNet50 in particular, `--opt-level O3 --keep-batchnorm-fp32 True` establishes
the "speed of light."  (Without `--keep-batchnorm-fp32`, it's slower, because it does
not use cudnn batchnorm.)

#### `--opt-level O1` ("conservative mixed precision")

`O1` patches Torch functions to cast inputs according to a whitelist-blacklist model.
FP16-friendly (Tensor Core) ops like gemms and convolutions run in FP16, while ops
that benefit from FP32, like batchnorm and softmax, run in FP32.
Also, dynamic loss scaling is used by default.
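A small sketch of what that patching looks like from user code (illustrative only; the exact dtype outcomes depend on Amp's internal lists and version):
```
import torch
from apex import amp

model = torch.nn.Conv2d(3, 64, kernel_size=7).cuda()    # weights stay FP32 under O1
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 3, 224, 224, device="cuda")           # FP32 input, no .half() needed
out = model(x)
print(out.dtype)   # expected torch.float16: convolutions are whitelisted to run in FP16
```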
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
```
`O1` overridden to use static loss scaling:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./
```
Distributed training with 2 processes (1 GPU per process, see **Distributed training** below
for more detail):
```
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
```
For best performance, set `--nproc_per_node` equal to the total number of GPUs on the node
to use all available resources.

#### `--opt-level O2` ("fast mixed precision")

`O2` casts the model to FP16, keeps batchnorms in FP32,
maintains master weights in FP32, and implements
dynamic loss scaling by default. (Unlike `--opt-level O1`, `--opt-level O2`
does not patch Torch functions.)
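A sketch of what `O2` means for parameter storage, using `amp.master_params` to peek at the optimizer's FP32 master copies (illustrative; batchnorm parameters remain FP32 as noted above):
```
import torch
import torchvision.models as models
from apex import amp

model = models.resnet50().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

print(next(model.parameters()).dtype)             # torch.float16: the model was cast
print(next(amp.master_params(optimizer)).dtype)   # torch.float32: master weights stay FP32
```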
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```
"Fast mixed precision" overridden to use static loss scaling:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./
```
Distributed training with 2 processes (1 GPU per process):
```
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```

## Distributed training

`main_amp.py` optionally uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.
```
model = apex.parallel.DistributedDataParallel(model)
```
is a drop-in replacement for
```
model = torch.nn.parallel.DistributedDataParallel(model,
                                                  device_ids=[args.local_rank],
                                                  output_device=args.local_rank)
```
(Because Torch DDP permits multiple GPUs per process, it requires you to
manually specify the device to run on and the output device.
Apex DDP uses only the current device by default.)

The choice of DDP wrapper (Torch or Apex) is orthogonal to the use of Amp and other Apex tools.  It is safe to use `apex.amp` with either `torch.nn.parallel.DistributedDataParallel` or `apex.parallel.DistributedDataParallel`.  In the future, I may add some features that permit optional tighter integration between `Amp` and `apex.parallel.DistributedDataParallel` for marginal performance benefits, but currently, there's no compelling reason to use Apex DDP versus Torch DDP for most models.

To use DDP with `apex.amp`, the only gotcha is that
```
model, optimizer = amp.initialize(model, optimizer, flags...)
```
must precede
```
model = DDP(model)
```
If DDP wrapping occurs before `amp.initialize`, `amp.initialize` will raise an error.

With both Apex DDP and Torch DDP, you must also call `torch.cuda.set_device(args.local_rank)` within
each process prior to initializing your model or any other tensors.
More information can be found in the docs for the
Pytorch multiprocess launcher module [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility).
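Putting those ordering requirements together, the per-process setup looks roughly like the following sketch (not a copy of `main_amp.py`; the argument parsing and the NCCL backend choice are assumptions):
```
import argparse
import torch
import torchvision.models as models
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)   # filled in by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)                      # before creating any CUDA tensors
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = models.resnet50().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize must precede DDP wrapping
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = DDP(model)
```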

`main_amp.py` is written to interact with 
[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility),
which spawns multiprocess jobs using the following syntax:
```
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_amp.py args...
```
`NUM_GPUS` should be less than or equal to the number of visible GPU devices on the node.  The use of `torch.distributed.launch` is unrelated to the choice of DDP wrapper.  It is safe to use either apex DDP or torch DDP with `torch.distributed.launch`.

Optionally, one can run ImageNet training with synchronized batch normalization across processes by adding
`--sync_bn` to the `args...`.
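In your own script, the usual Apex route to the same behavior is `apex.parallel.convert_syncbn_model`, applied before `amp.initialize` and DDP wrapping; a sketch (not necessarily the exact code path `--sync_bn` takes in `main_amp.py`, and it assumes the process-group setup shown above):
```
import apex
import torch
import torchvision.models as models
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

model = models.resnet50()
# replace torch.nn.BatchNorm layers with batchnorm synchronized across processes
model = apex.parallel.convert_syncbn_model(model).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = DDP(model)
```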

## Deterministic training (for debugging purposes)

Running with the `--deterministic` flag should produce bitwise identical outputs run-to-run,
regardless of what other options are used (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)).
Since `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may
cause a modest performance decrease.
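The flag bundles the usual Pytorch determinism settings; a sketch of that kind of setup (the exact seeds and settings used by `main_amp.py` may differ):
```
import random
import numpy as np
import torch

torch.backends.cudnn.benchmark = False       # disabled by --deterministic, as noted above
torch.backends.cudnn.deterministic = True    # force deterministic cudnn kernels

# seed every RNG so runs are comparable bit for bit
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
```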