OpenDAS/apex commit b82c6bd7, authored Jun 03, 2019 by Michael Carilli
Adding min_loss_scale and max_loss_scale arguments to amp.initialize
Parent: 8be5b6be

Showing 4 changed files with 24 additions and 184 deletions
apex/amp/README.md       +1   -179
apex/amp/_initialize.py  +3   -1
apex/amp/frontend.py     +10  -0
apex/amp/scaler.py       +10  -4
apex/amp/README.md @ b82c6bd7
# amp: Automatic Mixed Precision
## This README documents the deprecated (pre-unified) API.
## Documentation for the current unified API can be found [here](https://nvidia.github.io/apex/)
amp is an experimental tool to enable mixed precision training in
PyTorch with extreme simplicity and overall numerical safety. It
does so by employing a whitelist / blacklist model (illustrated in the
short sketch after this list):

- Any function on the whitelist casts its input arguments to fp16.
  These are functions like `torch.conv2d` that can take advantage of
  TensorCore execution.
- Any function on the blacklist casts its input arguments to fp32.
  These are functions like `torch.exp` or loss functions that have
  trouble with the numerical properties of fp16.
- Any other function passes along its input types to its outputs. Care
  is taken so that multi-argument functions or methods (e.g.
  `torch.tensor.__add__`) can handle mixed-type inputs; they simply
  promote all inputs to have the widest type of any input.
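A minimal sketch of those three rules (assuming apex is installed and a CUDA
device is available; the tensor shapes and values are arbitrary, chosen only
for illustration):

```python
import torch
import torch.nn.functional as F
from apex import amp

amp_handle = amp.init()

x = torch.randn(8, 3, 32, 32, device='cuda')  # fp32 activations
w = torch.randn(16, 3, 3, 3, device='cuda')   # fp32 weights

y = F.conv2d(x, w)  # whitelist: inputs are cast to fp16 for TensorCore execution
z = torch.exp(y)    # blacklist: input is cast to fp32 for numerical safety
s = y + z           # mixed fp16/fp32 inputs: promoted to the widest type (fp32)
```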
The PyTorch hooks that enable the necessary casts are at the low-level
functional interface to PyTorch, so even custom layers will work with
amp, so long as they are built out of PyTorch functions and methods.
In particular, amp hooks into all of the following:

- Functions in the top-level `torch` namespace
- Functions in the `torch.nn.functional` namespace
- Methods on `Tensor` objects (GPU only, fp16 and fp32)
- Custom support for RNNs, even though they have no direct functional
  interface:
  - Recurrent cells: `torch.nn.{RNNCell, LSTMCell, GRUCell}`
  - Recurrent layers: `torch.nn.{RNN, LSTM, GRU}`

In a few limited cases, amp needs help finding custom user-defined
functions that use low-level PyTorch features. In those cases, a
simple annotation is sufficient; this is described below.
## Installation and Requirements
amp is developed on Python 3.6 and PyTorch 0.4. It takes care to be
backwards-compatible with PyTorch 0.3, but users are _highly_
encouraged to upgrade.
amp is installed during normal apex installation, so refer to the
top-level README for more on installation.
## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which simplifies command-line arguments. If
False, then everything amp does will be a zero-overhead pass-through
-- i.e., your code will run as-is.

For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False),
toggles whether to print out every cast that occurs; this is mostly
useful for debugging.
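For example, a sketch of a typical command-line setup (the `--use_fp16` flag
and the argparse boilerplate are illustrative assumptions, not part of amp):

```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--use_fp16', action='store_true')
args = parser.parse_args()

# With enabled=False, amp.init() is a zero-overhead pass-through,
# so the same script runs unchanged in pure fp32.
amp_handle = amp.init(enabled=args.use_fp16, enable_caching=True, verbose=False)
```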
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
loss.backward()
optimizer.step()
# ...finish the iteration
```
To use amp, you need only tell it where backprop occurs:
```python
# ... same as before
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ... same as before
```
This context manager allows amp to:

1. Use automatic loss scaling to best use fp16 range
2. Clear its cache of casted parameters before the next optimizer step

Note that it is _possible_ to use amp without step 2, in which case
you will not get automatic loss scaling, nor is it safe to
`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
## Multiple Optimizers or Backward Passes

Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex, with:

- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
  PyTorch's gradient accumulation

To work with such models, amp requires you to explicitly wrap each
optimizer and indicate if it will have more than one backward pass
per iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = # ... some optimizer
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
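For example, a minimal sketch of the multi-optimizer pattern (the two tiny
models, their losses, and the CUDA device are illustrative assumptions):

```python
import torch
from apex import amp

gen = torch.nn.Linear(16, 16).cuda()
disc = torch.nn.Linear(16, 1).cuda()
opt_g = torch.optim.SGD(gen.parameters(), lr=1e-3)
opt_d = torch.optim.SGD(disc.parameters(), lr=1e-3)

amp_handle = amp.init()
opt_g = amp_handle.wrap_optimizer(opt_g)  # each optimizer is wrapped individually
opt_d = amp_handle.wrap_optimizer(opt_d)

x = torch.randn(8, 16, device='cuda')

loss_g = gen(x).pow(2).mean()
with opt_g.scale_loss(loss_g) as scaled_loss:
    scaled_loss.backward()
opt_g.step()

loss_d = disc(x).pow(2).mean()
with opt_d.scale_loss(loss_d) as scaled_loss:
    scaled_loss.backward()
opt_d.step()
```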
#### Handling multiple backward passes

PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step()  # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)

# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
    scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
    scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
...
...
@@ -238,7 +61,7 @@ registration:
 When using this API, `module` is the containing class or module for
 the function, and `function_name` is the _string_ name of the
 function. Note that the function must be registered before the call to
-`amp.init()`.
+`amp.initialize()`.
For our FRU unit, we can register the backend function directly:
...
...
@@ -246,5 +69,4 @@ For our FRU unit, we can register the backend function directly:
```python
import backend
amp.register_half_function(backend, 'FRUBackend')
amp.init()
```
apex/amp/_initialize.py @ b82c6bd7
...
...
@@ -231,7 +231,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     _amp_state.loss_scalers = []
     for _ in range(num_losses):
-        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale))
+        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
+                                                  min_loss_scale=_amp_state.min_loss_scale,
+                                                  max_loss_scale=_amp_state.max_loss_scale))
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
...
...
apex/amp/frontend.py @ b82c6bd7
...
...
@@ -204,6 +204,8 @@ def initialize(
     cast_model_outputs=None,
     num_losses=1,
     verbosity=1,
+    min_loss_scale=None,
+    max_loss_scale=2.**24
     ):
     """
     Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
...
...
@@ -251,6 +253,11 @@ def initialize(
             support multiple losses/backward passes, but use a single global loss scale
             for all of them.
         verbosity (int, default=1): Set to 0 to suppress Amp-related output.
+        min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
+            loss scaling. The default value of None means that no floor is imposed.
+            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
+        max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
+            dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
     Returns:
         Model(s) and optimizer(s) modified according to the ``opt_level``.
...
...
@@ -318,6 +325,9 @@ def initialize(
     for k, v in _amp_state.opt_properties.options.items():
         maybe_print("{:22} : {}".format(k, v), True)

+    _amp_state.min_loss_scale = min_loss_scale
+    _amp_state.max_loss_scale = max_loss_scale
+
     maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
     # I chose to have the keyword arguments listed directly in the argument list,
     # instead of **kwargs, so I can't use kwargs.items() here.
...
...
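As a usage sketch for the new keyword arguments added by this commit: with
dynamic loss scaling enabled, a call like the one below should keep the loss
scale between the given floor and ceiling. The model, optimizer, and the
`opt_level` value are illustrative assumptions, not part of the diff:

```python
import torch
from apex import amp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# min_loss_scale puts a floor under the scale chosen by dynamic loss scaling;
# max_loss_scale caps it (default 2.**24). Both are ignored with a static loss scale.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  min_loss_scale=128.0, max_loss_scale=2.**16)
```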
apex/amp/scaler.py @ b82c6bd7
...
...
@@ -40,14 +40,17 @@ class LossScaler(object):
                  loss_scale,
                  init_scale=2.**16,
                  scale_factor=2.,
-                 scale_window=2000):
+                 scale_window=2000,
+                 min_loss_scale=None,
+                 max_loss_scale=2.**24):
         if loss_scale == "dynamic":
             self.dynamic = True
             self._loss_scale = init_scale
         else:
             self.dynamic = False
             self._loss_scale = loss_scale
-        self._max_loss_scale = 2.**24
+        self._max_loss_scale = max_loss_scale
+        self._min_loss_scale = min_loss_scale
         self._scale_seq_len = scale_window
         self._unskipped = 0
         self._has_overflow = False
...
...
@@ -191,14 +194,17 @@ class LossScaler(object):
         if self._has_overflow and self.dynamic:
             should_skip = True
-            self._loss_scale /= 2.
+            if(self._min_loss_scale):
+                self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
+            else:
+                self._loss_scale = self._loss_scale/2.
             self._unskipped = 0
         else:
             should_skip = False
             self._unskipped += 1

         if self._unskipped == self._scale_seq_len and self.dynamic:
-            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
+            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
             self._unskipped = 0

         return should_skip
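For intuition, here is a small standalone sketch of the clamped update rule
above (not apex code; the function and the demo values are illustrative): on
overflow the scale is halved but never drops below the optional floor, and
after a full window of clean steps it is doubled but never exceeds the ceiling.

```python
def update_scale(loss_scale, overflow, unskipped, scale_window=2000,
                 min_loss_scale=None, max_loss_scale=2.**24):
    """Sketch of dynamic loss scaling with the min/max clamps from this commit."""
    if overflow:
        # Halve the scale, but respect the optional floor.
        loss_scale = max(min_loss_scale, loss_scale / 2.) if min_loss_scale else loss_scale / 2.
        unskipped = 0
        skip_step = True
    else:
        skip_step = False
        unskipped += 1
        if unskipped == scale_window:
            # Double the scale after a clean window, up to the ceiling.
            loss_scale = min(max_loss_scale, loss_scale * 2.)
            unskipped = 0
    return loss_scale, unskipped, skip_step

# After an overflow at scale 2.**8 with min_loss_scale=128., the scale stops at the floor:
scale, unskipped, skip = update_scale(2.**8, overflow=True, unskipped=0, min_loss_scale=128.)
assert scale == 128. and skip
```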