.. role:: hidden
    :class: hidden-section

apex.amp
===================================

This page documents the updated API for Amp (Automatic Mixed Precision),
a tool to enable Tensor Core-accelerated training in only 3 lines of Python.

Users **should not** manually cast their model or data to ``.half()``, regardless of what ``opt_level``
or properties are chosen.  Amp intends that users start with an existing default (FP32) script,
add the three lines corresponding to the Amp API, and begin training with mixed precision.
Amp can also be disabled, in which case the original script will behave exactly as it used to.
In this way, adhering to the Amp API carries no risk and offers a lot of potential performance benefit.

Example::

        import torch
        from apex import amp

        # Declare model and optimizer as usual, with default (FP32) precision.
        # Do not call .half() here; Amp performs any casts itself.
        model = torch.nn.Linear(D_in, D_out).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # Allow Amp to perform casts as required by the opt_level
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        ...
        # loss.backward() becomes:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        ...

A `runnable, comprehensive Imagenet example`_ demonstrating good practices can be found
on the GitHub page.

GANs are a tricky case that many people have requested.  A `comprehensive DCGAN example`_
is under construction.

.. _`runnable, comprehensive Imagenet example`:
    https://github.com/NVIDIA/apex/tree/master/examples/imagenet

.. _`comprehensive DCGAN example`:
    https://github.com/NVIDIA/apex/tree/master/examples/dcgan

``opt_level``\ s and Properties
-------------------------------

Amp allows users to easily experiment with different pure and mixed precision modes, including
pure FP16 training and pure FP32 training.  Commonly-used default modes are chosen by
selecting an "optimization level" or ``opt_level``; each ``opt_level`` establishes a set of
properties that govern Amp's implementation of pure or mixed precision training.
Finer-grained control of how a given ``opt_level`` behaves can be achieved by passing values for
particular properties directly to ``amp.initialize``.  These manually specified values will
override the defaults established by the ``opt_level``.

Properties
**********

Currently, the under-the-hood properties that govern pure or mixed precision training are the following:

- ``cast_model_type``:  Casts your model's parameters and buffers to the desired type.
- ``patch_torch_functions``: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32.
- ``keep_batchnorm_fp32``:  To enhance precision and enable cuDNN batchnorm (which improves performance), it's often beneficial to keep batchnorms in particular in FP32 even if the rest of the model is FP16.
- ``master_weights``:  Maintain FP32 master weights to accompany any FP16 model weights.  FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
- ``loss_scale``:  If ``loss_scale`` is a float value, use this value as the static (fixed) loss scale.  If ``loss_scale`` is the string ``"dynamic"``, adaptively adjust the loss scale over time.  Dynamic loss scale adjustments are performed by Amp automatically.
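
The motivation for ``master_weights`` and ``loss_scale`` can be illustrated without a GPU.
The sketch below is not Amp code.  It rounds Python floats to FP16 with the standard
library's ``struct`` half-precision format; the helper name ``to_fp16``, the loss scale of
1024, and the gradient and update sizes are assumptions chosen only for illustration, and
Python's FP64 floats stand in for FP32::

    import struct

    def to_fp16(x):
        """Round a Python float to the nearest IEEE half-precision (FP16) value."""
        return struct.unpack("e", struct.pack("e", x))[0]

    # A small gradient underflows to zero when stored in FP16 ...
    grad = 1e-8
    assert to_fp16(grad) == 0.0

    # ... but survives if the loss (and therefore the gradient) is scaled up first.
    loss_scale = 1024.0
    scaled_grad = to_fp16(grad * loss_scale)
    assert scaled_grad > 0.0
    recovered = scaled_grad / loss_scale  # unscale in higher precision

    # A small update is lost when applied directly to an FP16 weight ...
    fp16_weight = 1.0
    master_weight = 1.0  # higher-precision master copy
    for _ in range(10):
        update = 1e-4
        fp16_weight = to_fp16(fp16_weight + update)
        master_weight += update

    # ... but accumulates in the master copy, which is then cast to FP16.
    print(fp16_weight, to_fp16(master_weight))  # prints: 1.0 1.0009765625

The FP16 weight never moves because each update is smaller than half the spacing between
adjacent FP16 values near 1.0, while the master copy accumulates all ten updates.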

Again, you often don't need to specify these properties by hand.  Instead, select an ``opt_level``,
which will set them up for you.  After selecting an ``opt_level``, you can optionally pass property
kwargs as manual overrides.

If you attempt to override a property that does not make sense for the selected ``opt_level``,
Amp will raise an error with an explanation.  For example, selecting ``opt_level="O1"`` combined with
the override ``master_weights=True`` does not make sense.  ``O1`` inserts casts
around Torch functions rather than model weights.  Data, activations, and weights are recast
out-of-place on the fly as they flow through patched functions.  Therefore, the model weights themselves
can (and should) remain FP32, and there is no need to maintain separate FP32 master weights.
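
The relationship between an ``opt_level`` and its property overrides can be sketched in
plain Python.  This is a minimal illustration, not Amp's implementation: the helper
``resolve_properties`` and the ``OPT_LEVEL_DEFAULTS`` table are hypothetical names, the
exception type is an assumption, and the only consistency check modelled is the ``O1`` with
``master_weights=True`` case described above.  The default values are the ones listed
for each ``opt_level`` on this page, with dtypes written as strings so the sketch runs without PyTorch::

    # Defaults as listed for each opt_level on this page (dtypes as strings).
    OPT_LEVEL_DEFAULTS = {
        "O0": dict(cast_model_type="torch.float32", patch_torch_functions=False,
                   keep_batchnorm_fp32=None, master_weights=False, loss_scale=1.0),
        "O1": dict(cast_model_type=None, patch_torch_functions=True,
                   keep_batchnorm_fp32=None, master_weights=None, loss_scale="dynamic"),
        "O2": dict(cast_model_type="torch.float16", patch_torch_functions=False,
                   keep_batchnorm_fp32=True, master_weights=True, loss_scale="dynamic"),
        "O3": dict(cast_model_type="torch.float16", patch_torch_functions=False,
                   keep_batchnorm_fp32=False, master_weights=False, loss_scale=1.0),
    }

    def resolve_properties(opt_level, **overrides):
        """Hypothetical helper: start from the opt_level defaults, then apply overrides."""
        if opt_level not in OPT_LEVEL_DEFAULTS:
            raise ValueError(f"Unknown opt_level: {opt_level!r}")
        properties = dict(OPT_LEVEL_DEFAULTS[opt_level])
        for name, value in overrides.items():
            if name not in properties:
                raise ValueError(f"Unknown property: {name!r}")
            # The one incompatibility described above; Amp's real checks may differ.
            if opt_level == "O1" and name == "master_weights" and value:
                raise ValueError("master_weights=True does not make sense with O1")
            properties[name] = value
        return properties

    # O2 defaults with a static loss scale instead of dynamic loss scaling.
    props = resolve_properties("O2", loss_scale=128.0)
    print(props["loss_scale"], props["master_weights"])  # prints: 128.0 True

In real code the equivalent override is passed directly to ``amp.initialize`` as a keyword
argument alongside ``opt_level``.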

``opt_level``\ s
****************

``O0`` and ``O3`` are not true mixed precision, but they are useful for establishing baselines.

``O0``:  FP32 training
^^^^^^^^^^^^^^^^^^^^^^
Your incoming model should be FP32 already, so this is likely a no-op.
``O0`` can be useful to establish an accuracy baseline.

| Default properties set by ``O0``:
| ``cast_model_type=torch.float32``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=None`` (effectively, "not applicable")
| ``master_weights=False``
| ``loss_scale=1.0``
|
|

``O3``:  FP16 training
^^^^^^^^^^^^^^^^^^^^^^
``O3`` may not achieve the stability of the true mixed precision options ``O1`` and ``O2``.
However, it can be useful to establish a speed baseline for your model, against which
the performance of ``O1`` and ``O2`` can be compared.  If your model uses batch normalization,
to establish "speed of light" you can try ``O3`` with the additional property override
``keep_batchnorm_fp32=True`` (which enables cuDNN batchnorm, as stated earlier).

| Default properties set by ``O3``:
| ``cast_model_type=torch.float16``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=False``
| ``master_weights=False``
| ``loss_scale=1.0``
|
|

``O1`` and ``O2`` are different implementations of mixed precision.  Try both, and see
what gives the best speedup and accuracy for your model.

``O1``:  Conservative Mixed Precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Patch all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist
model.  Whitelist ops (for example, Tensor Core-friendly ops like GEMMs and convolutions) are performed
in FP16.  Blacklist ops that benefit from FP32 precision (for example, batchnorm and softmax)
are performed in FP32.  ``O1`` also uses dynamic loss scaling, unless overridden.
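
The whitelist-blacklist idea can be pictured with a toy namespace.  The sketch below is an
assumption-laden illustration rather than Amp's patching code: ``ops``, ``patch``,
``to_fp16``, and the two list contents are hypothetical, scalar ``mul`` stands in for a GEMM,
``exp`` stands in for a precision-sensitive op, and Python's native floats stand in for FP32::

    import math
    import struct
    import types

    def to_fp16(x):
        """Round a Python float to the nearest IEEE half-precision (FP16) value."""
        return struct.unpack("e", struct.pack("e", x))[0]

    # A toy "namespace" of ops standing in for Torch functions.
    ops = types.SimpleNamespace(mul=lambda a, b: a * b, exp=math.exp)

    WHITELIST = ["mul"]  # inputs are cast to FP16 before the call
    BLACKLIST = ["exp"]  # inputs are kept in full precision

    def patch(namespace, name, cast):
        """Replace namespace.<name> with a wrapper that casts every input first."""
        original = getattr(namespace, name)

        def patched(*args):
            return original(*(cast(a) for a in args))

        setattr(namespace, name, patched)

    for name in WHITELIST:
        patch(ops, name, to_fp16)
    for name in BLACKLIST:
        patch(ops, name, float)

    # The caller's value is untouched; casts happen out-of-place inside each call.
    x = 1.0001
    print(ops.mul(x, 2.0))  # x is rounded to FP16 (1.0) first, so this prints 2.0
    print(ops.exp(x) == math.exp(x))  # full precision is preserved: prints True

Because the casts live in the patched functions, the stored value ``x`` keeps its original
precision, which mirrors why model weights can remain FP32 under ``O1``.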

| Default properties set by ``O1``:
| ``cast_model_type=None`` (not applicable)
| ``patch_torch_functions=True``
| ``keep_batchnorm_fp32=None`` (not necessary to specify True, batchnorm inputs are cast to FP32)
| ``master_weights=None`` (not applicable, model weights remain FP32)
| ``loss_scale="dynamic"``
|
|

``O2``:  Fast Mixed Precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``O2`` casts the model to FP16, keeps batchnorms in FP32, maintains master weights in FP32,
and implements dynamic loss scaling (unless overridden).
Unlike ``O1``, ``O2`` does not patch Torch functions or Tensor methods.
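
The dynamic loss scaling that ``O1`` and ``O2`` enable by default can be pictured with the
following sketch.  It shows one common scheme (halve the scale when a gradient overflows
and skip that step, double it after a run of overflow-free steps) and is not a description of
Amp's internals: the class ``DynamicLossScaler`` and its parameter names and default values
are hypothetical::

    import math

    class DynamicLossScaler:
        """Hypothetical scaler: back off on overflow, grow after a run of clean steps."""

        def __init__(self, init_scale=2.0 ** 16, factor=2.0, growth_interval=2000):
            self.scale = init_scale
            self.factor = factor
            self.growth_interval = growth_interval
            self.clean_steps = 0

        def update(self, grads):
            """Return True if the optimizer step should be applied for these gradients."""
            if any(not math.isfinite(g) for g in grads):
                self.scale /= self.factor  # overflow: shrink the scale, skip the step
                self.clean_steps = 0
                return False
            self.clean_steps += 1
            if self.clean_steps == self.growth_interval:
                self.scale *= self.factor  # stable for a while: try a larger scale
                self.clean_steps = 0
            return True

    scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
    print(scaler.update([0.5, float("inf")]), scaler.scale)  # prints: False 512.0
    print(scaler.update([0.5, 0.25]), scaler.scale)          # prints: True 512.0
    print(scaler.update([0.1, 0.2]), scaler.scale)           # prints: True 1024.0

With Amp none of this bookkeeping appears in user code; the ``amp.scale_loss`` context
manager is the only place the loss scale surfaces.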

| Default properties set by ``O2``:
| ``cast_model_type=torch.float16``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=True``
| ``master_weights=True``
| ``loss_scale="dynamic"``
|
|


Unified API
-----------

.. automodule:: apex.amp
.. currentmodule:: apex.amp

.. autofunction:: initialize

.. autofunction:: scale_loss

.. autofunction:: master_params

Advanced use cases
------------------

The unified Amp API supports gradient accumulation across iterations,
multiple backward passes per iteration, multiple models/optimizers,
and custom/user-defined autograd functions.  Gradient clipping and GANs also
require special treatment, but this treatment does not need to change
for different ``opt_level``\ s.  Further details can be found here:

.. toctree::
   :maxdepth: 1

   advanced

Transition guide for old API users
----------------------------------

We strongly encourage moving to the new Amp API, because it's more versatile, easier to use, and future-proof.  The original :class:`FP16_Optimizer` and the old "Amp" API are deprecated, and subject to removal at any time.

For users of the old "Amp" API
******************************

In the new API, ``opt_level="O1"`` performs the same patching of the Torch namespace as the old
API called "Amp."
However, the new API allows static or dynamic loss scaling, while the old API only allowed dynamic loss scaling.

In the new API, the old call to ``amp_handle = amp.init()``, and the returned ``amp_handle``, are no
longer exposed or necessary.  The new ``amp.initialize()`` does the duty of ``amp.init()`` (and more).
Therefore, any existing calls to ``amp_handle = amp.init()`` should be deleted.

The functions formerly exposed through ``amp_handle`` are now free
functions accessible through the ``amp`` module.

The backward context manager must be changed accordingly::

    # old API
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    ->
    # new API
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

For now, the deprecated "Amp" API documentation can still be found in the GitHub README:  https://github.com/NVIDIA/apex/tree/master/apex/amp.  The old API calls that `annotate user functions`_ to run
with a particular precision are still honored by the new API.

.. _`annotate user functions`:
    https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions


For users of the old FP16_Optimizer
***********************************

``opt_level="O2"`` is equivalent to :class:`FP16_Optimizer` with ``dynamic_loss_scale=True``.
Once again, the backward pass must be changed to the unified version::

    optimizer.backward(loss)
    ->
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

One annoying aspect of FP16_Optimizer was that the user had to manually convert their model to half
(either by calling ``.half()`` on it, or using a function or module wrapper from
``apex.fp16_utils``), and also manually call ``.half()`` on input data.  **Neither of these is
necessary in the new API.  No matter what opt_level
you choose, you can and should simply build your model in the default FP32 format.**  The new Amp
API will perform the right conversions during
``model, optimizer = amp.initialize(model, optimizer, opt_level=....)`` based on the ``opt_level``
and any overridden properties.  Floating point input data may be FP32 or FP16, but you may as well just
let it be FP32, because the ``model`` returned by ``amp.initialize`` will have its ``forward``
method patched to cast the input data appropriately.

.. note::
    Aside from the call to ``amp.initialize`` itself, it's never necessary to manually cast
    your model or data with the new API.  Therefore, a script that adheres to the new API
    can switch between different ``opt_level``\ s without having to make any other changes.