OpenDAS / apex · Commit d6db91a4

Updating latest amp changes to use new C++ backend

Authored Jun 06, 2018 by Michael Carilli
Parents: 89564b69, db6ae13a
Showing 10 changed files with 467 additions and 90 deletions (+467 -90)
- .gitignore (+2 -1)
- apex/amp/README.md (+144 -27)
- apex/amp/__init__.py (+2 -1)
- apex/amp/amp.py (+68 -18)
- apex/amp/handle.py (+54 -27)
- apex/amp/lists/functional_overrides.py (+12 -1)
- apex/amp/opt.py (+108 -0)
- apex/amp/scaler.py (+41 -0)
- apex/amp/utils.py (+8 -0)
- apex/amp/wrap.py (+28 -15)
.gitignore (+2 -1)

```diff
 apex.egg-info
 dist
 build
-docs/build
\ No newline at end of file
+docs/build
+*~
\ No newline at end of file
```
apex/amp/README.md (+144 -27)

````diff
@@ -41,7 +41,7 @@ top-level README for more on installation.
 ## Usage and Getting Started
 
-In the normal case, using amp requires adding two lines of code (and
+In the common case, using amp requires adding two lines of code (and
 an import). The first enables amp, so that it can hook into all the
 relevant PyTorch functions. The second tells it where backpropagation
 occurs so that it can properly scale the loss and clear internal
@@ -50,20 +50,25 @@ per-iteration state.
 #### 1. Enable amp
 
 ```python
 from apex import amp
-amp_handle = amp.enable()
+amp_handle = amp.init()
 ```
 
-`amp.enable()` takes two arguments, and the defaults are _highly_
-recommended. The first, `enable_caching` (default=True), indicates
-whether amp should cache fp16 casts of model parameters on a
-per-iteration basis. This prevents things like RNN cells used inside a
-loop from casting their weight matrices over and over. The second,
-`verbose` (default=False) toggles whether to print out every cast that
-occurs. Useful for debugging, mostly.
+`amp.init()` takes three (optional) arguments. The most useful is
+`enabled` (default=True), which simplifies command-line arguments. If
+False, then everything amp does will be a zero-overhead pass-through
+-- i.e., your code will run as-is.
+
+For the other two options, the defaults are _highly_ recommended. The
+first, `enable_caching` (default=True), indicates whether amp should
+cache fp16 casts of model parameters on a per-iteration basis. This
+prevents things like RNN cells used inside a loop from casting their
+weight matrices over and over. The second, `verbose` (default=False)
+toggles whether to print out every cast that occurs. Useful for
+debugging, mostly.
 
 #### 2. Wrap backpropagation
 
-Nearly all PyTorch training scripts have a loops that looks like:
+Nearly all PyTorch training scripts have a loop that looks like:
 
 ```python
 # ... do a bunch of stuff to compute a loss
@@ -91,9 +96,86 @@ you will not get automatic loss scaling, nor is it safe to
 `enable_caching`. (Power user note: you can manually clear the cache
 after each optimizer step with `amp_handle._clear_cache()`.)
 
+## Multiple Optimizers or Backward Passes
+
+Step (2) from the previous section works when you have one PyTorch
+optimizer and a single `loss.backward()` for each iteration. Some
+models are more complex with:
+- Multiple optimizer objects (over different parameters)
+- Multiple backward passes for each iteration, taking advantage of
+  PyTorch's gradient accumulation
+
+To work with such models, amp requires you to explicitly wrap each
+optimizer and indicate if it will have more than one backward pass
+per-iteration.
+
+#### Explicitly wrapping optimizers
+
+If you have more than one optimizer, then you must explicitly wrap
+each. (You can also do so with a single optimizer.) First, wrap the
+optimizer after initializing amp:
+```python
+optimizer = # ... some optimizer
+amp_handle = amp.init()
+optimizer = amp_handle.wrap_optimizer(optimizer)
+```
+
+Second, use `optimizer.scale_loss(...)` to indicate where backprop
+occurs:
+```python
+with optimizer.scale_loss(loss) as scaled_loss:
+    scaled_loss.backward()
+optimizer.step()
+# ...
+```
+
+In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
+sugar for first wrapping the optimizer and then calling
+`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
+multi-optimizer case, you must wrap each optimizer individually.
+
+#### Handling multiple backward passes
+
+PyTorch accumulates parameter gradients between calls to
+`zero_grad()`, so it is possible to perform multiple backward passes
+before making a parameter update:
+```python
+optimizer.zero_grad()
+loss1 = ComputeLoss1(model)
+loss1.backward()
+# ...
+loss2 = ComputeLoss2(model)
+loss2.backward()
+# ...
+optimizer.step()  # has gradient contributions from both backward passes
+```
+
+The amp optimizer wrapper supports an additional argument `num_loss`
+to work with code like this:
+```python
+amp_handle = amp.init()
+optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
+
+# ...
+optimizer.zero_grad()
+loss1 = ComputeLoss1(model)
+with optimizer.scale_loss(loss1) as scaled_loss:
+    scaled_loss.backward()
+# ...
+loss2 = ComputeLoss2(model)
+with optimizer.scale_loss(loss2) as scaled_loss:
+    scaled_loss.backward()
+# ...
+optimizer.step()
+```
+
 ## Annotating User Functions
 
-Nearly all PyTorch user code needs nothing more than steps one and two
+Nearly all PyTorch user code needs nothing more than the two steps
 above to use amp. After all, custom layers are built out of simpler
 PyTorch components, and amp already can see those.
@@ -103,27 +185,62 @@ cell called a "forgetful recurrent unit" that calls directly into a
 CUDA backend:
 
 ```python
+from backend import FRUBackend
+
 def fru(input, hidden, weight, bias):
-    # ... call to CUDA code
+    # call to CUDA code
+    FRUBackend(input, hidden, weight, bias)
 ```
 
-amp exposes two functions to handle this case: `register_fp16` and
-`register_fp32`. These add the given function to the white or
-blacklist, respectively. You can use them as a decorator:
+In this case, it is possible to get a runtime type mismatch. For
+example, you might have `input` in fp16, and `weight` in fp32, and amp
+doesn't have the visibility to insert an appropriate cast.
+
+amp exposes two ways to handle "invisible" backend code: function
+annotations and explicit registration.
+
+#### Function annotation
+
+The first way to handle backend code is a set of function annotations:
+- `@amp.half_function`
+- `@amp.float_function`
+- `@amp.promote_function`
+
+These correspond to:
+- Cast all arguments to fp16
+- Cast all arguments to fp32
+- If there are any type mismatches, cast everything to the widest type
+
+In our example, we believe that the FRU unit is fp16-safe and will get
+performance gains from casting its arguments to fp16, so we write:
 ```python
-@amp.register_fp16
+@amp.half_function
 def fru(input, hidden, weight, bias):
-    # ...
+    #...
 ```
 
-or as a library call:
+#### Explicit registration
+
+The other way to handle backend code is with explicit function
+registration:
+- `amp.register_half_function(module, function_name)`
+- `amp.register_float_function(module, function_name)`
+- `amp.register_promote_function(module, function_name)`
+
+When using this API, `module` is the containing class or module for
+the function, and `function_name` is the _string_ name of the
+function. Note that the function must be registered before the call to
+`amp.init()`.
+
+For our FRU unit, we can register the backend function directly:
 ```python
-from apex import amp
-amp.register_fp16(custom_module.fru)
-amp.enable()
+import backend
+
+amp.register_half_function(backend, 'FRUBackend')
+amp.init()
 ```
-
-Note that the function must be registered before the call to
-`amp.enable()`. The library call makes this simple. If the function is
-annotated, then you must ensure its module is loaded before the call
-to `amp.enable()`. Furthermore, this does not (yet) work with class
-methods, only free functions.
````
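As a quick orientation beyond the diff, here is a minimal sketch of the single-optimizer pattern the updated README describes, assuming a CUDA build of apex with the new `apex_C` extension; the model, optimizer, and data here are placeholders, not part of the commit.

```python
import torch
import torch.nn.functional as F
from apex import amp

# Placeholder model/optimizer/data; any standard PyTorch setup works the same way.
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loader = [(torch.randn(4, 10).cuda(), torch.randn(4, 10).cuda()) for _ in range(2)]

amp_handle = amp.init(enabled=True)  # pass enabled=False for a zero-overhead pass-through

for x, y in loader:
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    # Step (2) from the README: scale the loss so fp16 gradients do not underflow.
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```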
apex/amp/__init__.py (+2 -1)

```diff
-from .amp import enable, register_half, register_float
+from .amp import init, half_function, float_function, promote_function,\
+    register_half_function, register_float_function, register_promote_function
```
apex/amp/amp.py (+68 -18)

```diff
 from . import compat, utils, wrap
-from .handle import AmpHandle
+from .handle import AmpHandle, NoOpHandle
 from .lists import functional_overrides, torch_overrides, tensor_overrides
 
-import inspect
+import functools
 import itertools
 
 import torch
 
-_USER_REGISTRY = set()
+_DECORATOR_HANDLE = None
+_USER_CAST_REGISTRY = set()
+_USER_PROMOTE_REGISTRY = set()
 
-# Can be used as a @decorator directly on the fn
-# or called w/ arg by user before `enable()`
-def register_half(fn):
-    mod = inspect.getmodule(fn)
-    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_half))
-    return fn
-
-def register_float(fn):
-    mod = inspect.getmodule(fn)
-    _USER_REGISTRY.add((mod, fn.__name__, utils.maybe_float))
-    return fn
+def _decorator_helper(orig_fn, cast_fn, wrap_fn):
+    def wrapper(*args, **kwargs):
+        handle = _DECORATOR_HANDLE
+        if handle is None or not handle.is_active():
+            return orig_fn(*args, **kwargs)
+        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
+                                        handle.verbose)
+        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
+    return wrapper
+
+# Decorator form
+def half_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_half, wrap_fn)
+
+def float_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+def promote_function(fn):
+    wrap_fn = functools.partial(wrap.make_promote_wrapper)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+# Registry form
+def register_half_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
+
+def register_float_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
+
+def register_promote_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_PROMOTE_REGISTRY.add((module, name))
 
 # Top-level function to insert _all_ the hooks.
-def enable(enable_caching=True, verbose=False):
-    handle = AmpHandle(enable_caching)
+def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
+    global _DECORATOR_HANDLE
+
+    if not enabled:
+        handle = NoOpHandle()
+        _DECORATOR_HANDLE = handle
+        return handle
+
+    handle = AmpHandle(enable_caching, verbose)
 
     # 0) Force-{fp16, fp32} for user-annotated functions
-    for mod, fn, cast_fn in _USER_REGISTRY:
+    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
         try_caching = (cast_fn == utils.maybe_half)
         wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching, verbose)
-    _USER_REGISTRY.clear()
+    _USER_CAST_REGISTRY.clear()
+
+    # 0.5) Force-promote for user-annotated functions
+    for mod, fn in _USER_PROMOTE_REGISTRY:
+        wrap.promote(mod, fn, verbose)
+    _USER_PROMOTE_REGISTRY.clear()
 
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
@@ -101,4 +145,10 @@ def enable(enable_caching=True, verbose=False):
 
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # 6) Place error+print message on banned functions
+    if not allow_banned:
+        for fn, err_msg in functional_overrides.BANNED_FUNCS:
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+
+    _DECORATOR_HANDLE = handle
+
     return handle
```
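A small illustrative sketch of the new decorator form, assuming apex is importable on a CUDA machine: the module-level `_DECORATOR_HANDLE` stays `None` until `amp.init()` runs, so a decorated function is a plain pass-through before initialization and only starts casting afterwards.

```python
import torch
from apex import amp

@amp.half_function  # inert until amp.init() installs the decorator handle
def scaled_dot(a, b):
    return (a * b).sum()

a = torch.randn(8).cuda()
b = torch.randn(8).cuda()

before = scaled_dot(a, b)   # no handle yet: runs in fp32
amp_handle = amp.init()
after = scaled_dot(a, b)    # arguments are now cast to fp16 first
print(before.dtype, after.dtype)
```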
apex/amp/handle.py (+54 -27)

```diff
@@ -2,54 +2,54 @@ import contextlib
 import logging
 import warnings
 
-import torch
-
-from apex_C import scale_check_overflow
+from .opt import OptimWrapper
+from .scaler import LossScaler
 
 class AmpHandle(object):
-    def __init__(self, enable_caching=True):
+    def __init__(self, enable_caching=True, verbose=False):
         self._enable_caching = enable_caching
+        self._verbose = verbose
         self._cache = dict()
-        self._loss_scale = 2.**16
-        self._max_loss_scale = 2.**24
-        self._scale_seq_len = 2000
-        self._unskipped = 0
-        self._overflow_buf = torch.cuda.ByteTensor(1024,)
+        self._default_scaler = LossScaler()
+
+    def is_active(self):
+        return True
+
+    def wrap_optimizer(self, optimizer, num_loss=1):
+        self._default_scaler = None
+        return OptimWrapper(optimizer, self, num_loss)
 
     @contextlib.contextmanager
     def scale_loss(self, loss, optimizer):
+        if not self.is_active():
+            yield loss
+            return
+
+        if self._default_scaler is None:
+            raise RuntimeError(
+                'After calling `handle.wrap_optimizer()`, you must explicitly '
+                + 'use `optimizer.scale_loss(loss)`.')
+
+        # TODO: this code block is duplicated here and `opt.py`. Unify.
         loss_backward = loss.backward
         def warning_wrapper():
             warnings.warn("You called .backward() on the unscaled loss "
                           "inside a scale_loss block. This is almost "
                           "certainly an error.", stacklevel=2)
             loss_backward()
         loss.backward = warning_wrapper
-        yield loss * self._loss_scale
+
+        loss_scale = self._default_scaler.loss_scale()
+        yield loss * loss_scale
+
         loss.backward = loss_backward
 
-        self._overflow_buf.zero_()
-        for group in optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    scale_check_overflow(p.grad.data,
-                                         1. / self._loss_scale,
-                                         self._overflow_buf)
-        if self._overflow_buf.any():
-            self._loss_scale /= 2.
+        should_skip = self._default_scaler.unscale_and_update(
+            optimizer.param_groups, loss_scale)
+        if should_skip:
             optimizer_step = optimizer.step
             def skip_step():
                 logging.info('Gradient overflow, skipping update')
                 optimizer.step = optimizer_step
             optimizer.step = skip_step
-            self._unskipped = 0
-        else:
-            self._unskipped += 1
-
-        if self._unskipped == self._scale_seq_len:
-            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-            self._unskipped = 0
 
         self._clear_cache()
@@ -63,3 +63,30 @@ class AmpHandle(object):
     @property
     def cache(self):
         return self._cache
+
+    def remove_cache(self, param):
+        if self.has_cache and param in self.cache:
+            del self.cache[param]
+
+    @property
+    def verbose(self):
+        return self._verbose
+
+class NoOpHandle(object):
+    def is_active(self):
+        return False
+
+    def wrap_optimizer(self, optimizer, num_loss=1):
+        return OptimWrapper(optimizer, self, num_loss)
+
+    @contextlib.contextmanager
+    def scale_loss(self, loss, optimizer):
+        yield loss
+
+    @property
+    def has_cache(self):
+        return False
+
+    @property
+    def verbose(self):
+        return False
```
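A short sketch of the new guard in `AmpHandle.scale_loss`, under the same CUDA/apex assumptions: once `wrap_optimizer()` has been called, the handle's default scaler is dropped, so using `amp_handle.scale_loss(...)` directly raises and you are expected to use `optimizer.scale_loss(...)` instead.

```python
import torch
from apex import amp

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)  # default scaler is now None

loss = model(torch.randn(2, 4).cuda()).sum()
try:
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:  # raises RuntimeError
        scaled_loss.backward()
except RuntimeError:
    with optimizer.scale_loss(loss) as scaled_loss:              # the supported path
        scaled_loss.backward()
optimizer.step()
```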
apex/amp/lists/functional_overrides.py (+12 -1)

```diff
@@ -42,7 +42,6 @@ FP32_FUNCS = [
 
     # Loss functions
     # TODO: which of these can be fp16?
-    'binary_cross_entropy',
     'poisson_nll_loss',
     'cosine_embedding_loss',
     'cross_entropy',
@@ -60,3 +59,15 @@ FP32_FUNCS = [
     'soft_margin_loss',
     'triplet_margin_loss'
 ]
+
+BANNED_FUNCS = [('binary_cross_entropy',
+                 ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
+                  "It requires that the output of the previous function be already a FloatTensor. \n\n"
+                  "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
+                  "    torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
+                  "that is compatible with amp.\nAnother option is to add\n"
+                  "    amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
+                  "If you _really_ know what you are doing, you can disable this warning by passing "
+                  "allow_banned=True to `amp.init()`."))]
```
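The banned-function message above points at two workarounds. The first is plain PyTorch and independent of apex: fold Sigmoid + BCELoss into `BCEWithLogitsLoss`, which computes the same quantity in a single, numerically stable op. A minimal CPU-only sketch:

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 1)
targets = torch.empty(8, 1).uniform_(0, 1)

# The Sigmoid + BCELoss pattern that becomes problematic under amp
# once the sigmoid output is fp16.
loss_bce = nn.BCELoss()(torch.sigmoid(logits), targets)

# The amp-friendly replacement: sigmoid and BCE fused into one layer.
loss_bce_logits = nn.BCEWithLogitsLoss()(logits, targets)

print(torch.allclose(loss_bce, loss_bce_logits, atol=1e-6))  # True, up to fp roundoff
```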
apex/amp/opt.py (new file, mode 100644, +108)

```python
import contextlib
import logging
import warnings

from .scaler import LossScaler, iter_params

import numpy as np

class OptimWrapper(object):
    def __init__(self, optimizer, amp_handle, num_loss):
        self._optimizer = optimizer
        self._amp_handle = amp_handle
        self._num_loss = num_loss
        self._loss_idx = 0
        self._skip_next = [False] * num_loss
        self._loss_scaler = [LossScaler() for _ in range(num_loss)]

    @contextlib.contextmanager
    def scale_loss(self, loss):
        if not self._amp_handle.is_active():
            yield loss
            return

        loss_backward = loss.backward
        def warning_wrapper():
            warnings.warn("You called .backward() on the unscaled loss "
                          "inside a scale_loss block. This is almost "
                          "certainly an error.", stacklevel=2)
            loss_backward()
        loss.backward = warning_wrapper

        # When there are multiple losses per-optimizer, we need
        # to save out current grad accumulation, since we won't be
        # able to unscale this particulare loss once the grads are
        # all mixed together.
        cached_grads = []
        if self._loss_idx > 0:
            for p in iter_params(self._optimizer.param_groups):
                if p.grad is not None:
                    cached_grads.append(p.grad.data.detach().clone())
                else:
                    cached_grads.append(None)
            self._optimizer.zero_grad()

        loss_scale = self._cur_loss_scaler().loss_scale()
        yield loss * loss_scale

        loss.backward = loss_backward

        self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
            self._optimizer.param_groups, loss_scale)
        self._loss_idx += 1

        if len(cached_grads) > 0:
            for p, cached_grad in zip(iter_params(self._optimizer.param_groups),
                                      cached_grads):
                if cached_grad is not None:
                    p.grad.data.add_(cached_grad)
            cached_grads = []

    def _cur_loss_scaler(self):
        assert 0 <= self._loss_idx < self._num_loss
        return self._loss_scaler[self._loss_idx]

    def step(self, closure=None):
        if not self._amp_handle.is_active():
            return self._optimizer.step(closure=closure)

        self._loss_idx = 0

        for group in self._optimizer.param_groups:
            for p in group['params']:
                self._amp_handle.remove_cache(p)

        if closure is not None:
            raise NotImplementedError(
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')
        if any(self._skip_next):
            logging.info('Gradient overflow, skipping update')
            self._skip_next = [False] * self._num_loss
        else:
            return self._optimizer.step(closure=closure)

    # Forward any attribute lookups
    def __getattr__(self, attr):
        return getattr(self._optimizer, attr)

    # Forward all torch.optim.Optimizer methods
    def __getstate__(self):
        return self._optimizer.__getstate__()

    def __setstate__(self):
        return self._optimizer.__setstate__()

    def __repr__(self):
        return self._optimizer.__repr__()

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self._optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self._optimizer.zero_grad()

    def add_param_group(self, param_group):
        return self._optimizer.add_param_group(param_group)
```
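A brief sketch of how the wrapper stays a drop-in replacement, again assuming a CUDA build of apex: anything `OptimWrapper` does not define is forwarded to the underlying optimizer through `__getattr__`, so existing code that touches `param_groups` (for example an LR schedule) keeps working unchanged.

```python
import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=1)

# param_groups resolves on the wrapped torch.optim.SGD via __getattr__.
for group in optimizer.param_groups:
    group['lr'] *= 0.1

print(optimizer)  # __repr__ is forwarded too
```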
apex/amp/scaler.py (new file, mode 100644, +41)

```python
import torch
from apex_C import scale_check_overflow

class LossScaler(object):
    def __init__(self):
        self._loss_scale = 2.**16
        self._max_loss_scale = 2.**24
        self._scale_seq_len = 2000
        self._unskipped = 0
        self._overflow_buf = torch.cuda.ByteTensor(1024,)

    def loss_scale(self):
        return self._loss_scale

    def unscale_and_update(self, param_groups, scale):
        self._overflow_buf.zero_()
        for p in iter_params(param_groups):
            if p.grad is not None:
                scale_check_overflow(p.grad.data,
                                     1. / scale,
                                     self._overflow_buf)
        if self._overflow_buf.any():
            should_skip = True
            self._loss_scale /= 2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
            self._unskipped = 0

        return should_skip

def iter_params(param_groups):
    for group in param_groups:
        for p in group['params']:
            yield p
```
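To make the dynamic loss-scaling policy concrete, here is a small pure-Python illustration of the same schedule (halve the scale on overflow, double it after 2000 clean steps, cap at 2**24). A boolean `overflow` flag stands in for the fused `scale_check_overflow` CUDA kernel; this is an illustration of the policy only, not the GPU path.

```python
def simulate_loss_scale(overflow_flags, scale=2.**16,
                        max_scale=2.**24, seq_len=2000):
    """Replay LossScaler's schedule over a sequence of per-step overflow flags."""
    unskipped = 0
    history = []
    for overflow in overflow_flags:
        if overflow:
            scale /= 2.      # back off immediately; the optimizer step is skipped
            unskipped = 0
        else:
            unskipped += 1
        if unskipped == seq_len:
            scale = min(max_scale, scale * 2.)  # grow again after a clean stretch
            unskipped = 0
        history.append(scale)
    return history

# One overflow early, then 4000 clean steps: the scale dips to 2**15,
# recovers to 2**16 after 2000 clean steps, and reaches 2**17 after 4000.
flags = [False] * 10 + [True] + [False] * 4000
print(simulate_loss_scale(flags)[-1])  # 131072.0
```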
apex/amp/utils.py (+8 -0)

```diff
@@ -85,7 +85,15 @@ def cached_cast(cast_fn, x, cache):
     if is_nested(x):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
+        cached_x = cache[x]
+        # During eval, it's possible to end up caching casted weights
+        # with requires_grad == False. This is then a problem when they
+        # get reused on the next train iter. So we ensure that cached
+        # weights have same requires_grad flag of most recent request.
+        if x.requires_grad != cached_x.requires_grad:
+            cached_x.requires_grad_(x.requires_grad)
         return cache[x]
 
     casted_x = cast_fn(x)
     cache[x] = casted_x
     return casted_x
```
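A tiny CPU-only sketch of the flag syncing that the new cache branch performs: `requires_grad_()` flips the flag in place on the cached cast, which is what keeps a cast made during evaluation usable on the next training iteration.

```python
import torch

w = torch.nn.Parameter(torch.randn(3, 3))   # master fp32 weight
with torch.no_grad():
    cached = w.half()                        # eval-time cast: requires_grad == False

print(cached.requires_grad)                  # False

# The fix in cached_cast: keep the cached tensor's flag in sync with the
# most recent request so the next training iteration can build a graph.
if w.requires_grad != cached.requires_grad:
    cached.requires_grad_(w.requires_grad)

print(cached.requires_grad)                  # True
```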
apex/amp/wrap.py (+28 -15)

```diff
@@ -5,13 +5,8 @@ import functools
 
 import torch
 
-def cached_cast(mod, fn, cast_fn, handle,
-                try_caching=False, verbose=False):
-    if not utils.has_func(mod, fn):
-        return
-
-    orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(cast_fn, fn, verbose)
+def make_cast_wrapper(orig_fn, cast_fn, handle,
+                      try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         if try_caching and handle.has_cache:
@@ -26,18 +21,27 @@ def cached_cast(mod, fn, cast_fn, handle,
                                          args,
                                          kwargs)
         return orig_fn(*new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    return wrapper
 
-def promote(mod, fn, verbose=False):
-    orig_fn = utils.get_func(mod, fn)
-    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
+def cached_cast(mod, fn, cast_fn, handle,
+                try_caching=False, verbose=False):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
+    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
+    utils.set_func(mod, fn, wrapper)
+
+# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
+def make_promote_wrapper(orig_fn, cast_fn, handle=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
         elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
-            new_args = utils.casted_args(maybe_float,
+            new_args = utils.casted_args(cast_fn,
                                          args,
                                          kwargs)
             return orig_fn(*new_args, **kwargs)
@@ -45,8 +49,14 @@ def promote(mod, fn, verbose=False):
             raise NotImplementedError('Do not know how to handle ' +
                                       'these types to promote: {}'
                                       .format(types))
+    return wrapper
+
+def promote(mod, fn, verbose=False):
+    orig_fn = utils.get_func(mod, fn)
+    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
+    wrapper = make_promote_wrapper(orig_fn, maybe_float)
     utils.set_func(mod, fn, wrapper)
 
 def sequence_promote(mod, fn, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@@ -84,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):
         return orig_fn(arg0, *new_args, **kwargs)
     utils.set_func(mod, fn, wrapper)
 
-def err_if_any_half(mod, fn):
+def err_if_any_half(mod, fn, custom_err_msg=None):
     if not utils.has_func(mod, fn):
         return
@@ -93,8 +103,11 @@ def err_if_any_half(mod, fn):
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
         if 'HalfTensor' in types:
-            raise NotImplementedError('Cannot call in-place function ' +
-                                      '{} with fp16 arguments.'.format(fn))
+            if custom_err_msg:
+                raise NotImplementedError(custom_err_msg)
+            else:
+                raise NotImplementedError('Cannot call in-place function ' +
+                                          '{} with fp16 arguments.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func(mod, fn, wrapper)
```
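For context on the mechanism all of these helpers share, here is a generic sketch of the patch-in-place pattern (fetch the original function, wrap it with `functools.wraps`, assign the wrapper back onto the module). It uses plain `getattr`/`setattr` and a logging wrapper rather than apex's `utils.get_func`/`utils.set_func` and cast wrappers.

```python
import functools
import torch
import torch.nn.functional as F

def patch_with_logging(mod, fn_name):
    """Replace mod.fn_name with a wrapper that logs each call, then delegates."""
    orig_fn = getattr(mod, fn_name)

    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        print('calling {}'.format(fn_name))
        return orig_fn(*args, **kwargs)

    setattr(mod, fn_name, wrapper)

patch_with_logging(F, 'relu')
F.relu(torch.randn(4))  # prints "calling relu", then behaves exactly as before
```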