OpenDAS / apex

Commit ebaa5a15
authored Aug 27, 2018 by Carl Case

experimental: ability to deactivate amp with handle

parent 437bcf22

Showing 4 changed files with 41 additions and 25 deletions
apex/amp/amp.py     +11  -11
apex/amp/handle.py  +11  -0
apex/amp/utils.py   +5   -0
apex/amp/wrap.py    +14  -14
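The mechanism, in short: init() now passes the amp handle into every wrapping call, each installer records the original (uncasted) function on the handle via utils.set_func_save before monkey-patching it, and a new experimental AmpHandle._deactivate() reinstalls all of the saved originals. A hypothetical end-to-end usage sketch follows; the `from apex import amp` entry point and the surrounding call sequence are assumptions for illustration, while init()'s signature and _deactivate() come from this commit.

# Hypothetical usage sketch (illustration only, not part of this commit):
from apex import amp                              # assumed package-level entry point

handle = amp.init(enabled=True, verbose=False)    # installs casting wrappers, records originals
# ... run mixed-precision work while amp is active ...
handle._deactivate()                              # experimental: restore the saved, uncasted functions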
apex/amp/amp.py

@@ -73,7 +73,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 0.5) Force-promote for user-annotated functions
     for mod, fn in _USER_PROMOTE_REGISTRY:
-        wrap.promote(mod, fn, verbose)
+        wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()

     # 1) Force-{fp16, fp32} on white- / black-list functions
@@ -107,7 +107,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
                                                                   promote_table):
         for fn in getattr(promote_mod, list_name):
-            promote_fn(promote_mod.MODULE, fn, verbose)
+            promote_fn(promote_mod.MODULE, fn, handle, verbose)

     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
@@ -115,27 +115,27 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
                                                                torch.cuda.HalfTensor],
                                                               promote_table):
             for fn in getattr(tensor_overrides, list_name):
-                promote_fn(cls, fn, verbose)
+                promote_fn(cls, fn, handle, verbose)

     # 3) For any in-place version of a blacklist function, error if any input is fp16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
-        wrap.err_if_any_half(torch_overrides.MODULE, fn)
+        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

     # 3.5) For any in-place blacklist method, error if called on fp16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
-        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, verbose)
+        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, verbose)
+            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)

     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
             tensor_overrides.FP16_FUNCS,
             tensor_overrides.CASTS)):
-        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, verbose)
+        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, verbose)
-            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, verbose)
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) Special handling to whitelist RNN cell backend impls.
     for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
@@ -143,7 +143,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
                          handle, try_caching=True, verbose=verbose)

     # 5.5) Extra-special handling of RNN backend
-    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
+    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)

     # And even more special handling of `backward` for fused gru / lstm
     # The `backward` method calls Tensor.sum() (blacklist) internally,
@@ -156,7 +156,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 6) Place error+print message on banned functions
     if not allow_banned:
         for fn, err_msg in functional_overrides.BANNED_FUNCS:
-            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)

     _DECORATOR_HANDLE = handle

     return handle
apex/amp/handle.py

@@ -2,6 +2,7 @@ import contextlib
 import logging
 import warnings

+from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
@@ -12,6 +13,7 @@ class AmpHandle(object):
         self._cache = dict()
         self._default_scaler = LossScaler()
         self._is_active = True
+        self._all_wrappers = []

     def is_active(self):
         return self._is_active
@@ -63,6 +65,15 @@ class AmpHandle(object):
     def _clear_cache(self):
         self._cache.clear()

+    # Experimental support for saving / restoring uncasted versions of functions
+    def _save_func(self, mod, fn, func):
+        self._all_wrappers.append((mod, fn, func))
+
+    def _deactivate(self):
+        for mod, fn, func in self._all_wrappers:
+            utils.set_func(mod, fn, func)
+        self._all_wrappers = []
+
     @property
     def has_cache(self):
         return self._enable_caching
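The bookkeeping added to AmpHandle is intentionally minimal: _save_func appends (module, name, original_fn) triples, and _deactivate walks them, reinstalling each original through utils.set_func and then clearing the list. A small illustrative snippet of the resulting behaviour (some_module, 'some_fn', and original_fn below are hypothetical stand-ins, not names from this commit):

# Illustration only: repeated deactivation is harmless because the list is cleared.
handle._save_func(some_module, 'some_fn', original_fn)  # roughly what set_func_save does internally
handle._deactivate()   # reinstalls original_fn via utils.set_func, empties _all_wrappers
handle._deactivate()   # no-op: nothing left to restore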
apex/amp/utils.py

@@ -126,6 +126,11 @@ def set_func(mod, fn, new_fn):
     else:
         setattr(mod, fn, new_fn)

+def set_func_save(handle, mod, fn, new_fn):
+    cur_fn = get_func(mod, fn)
+    handle._save_func(mod, fn, cur_fn)
+    set_func(mod, fn, new_fn)
+
 # A couple problems get solved here:
 # - The flat_weight buffer is disconnected from autograd graph,
 #   so the fp16 weights need to be derived from the input weights
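A self-contained toy sketch of the same save-then-patch-then-restore mechanism, using a stand-in namespace and a local reimplementation of the handle rather than apex internals (all names below are illustrative):

import types

class ToyHandle:                                   # stand-in for AmpHandle's new bookkeeping
    def __init__(self):
        self._all_wrappers = []
    def _save_func(self, mod, fn, func):
        self._all_wrappers.append((mod, fn, func))
    def _deactivate(self):
        for mod, fn, func in self._all_wrappers:   # reinstall every saved original
            setattr(mod, fn, func)
        self._all_wrappers = []

def set_func_save(handle, mod, fn, new_fn):        # mirrors the helper added above
    handle._save_func(mod, fn, getattr(mod, fn))   # remember the current function first
    setattr(mod, fn, new_fn)                       # then install the wrapper

mod = types.SimpleNamespace(add=lambda a, b: a + b)
handle = ToyHandle()
set_func_save(handle, mod, 'add', lambda a, b: float(a) + float(b))  # "patched" version
assert mod.add(1, 2) == 3.0
handle._deactivate()                               # restore the original
assert mod.add(1, 2) == 3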
apex/amp/wrap.py

@@ -34,7 +34,7 @@ def cached_cast(mod, fn, cast_fn, handle,
     orig_fn = utils.get_func(mod, fn)
     cast_fn = utils.verbosify(cast_fn, fn, verbose)
     wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

 # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
 def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@@ -54,13 +54,13 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
                          .format(types))
     return wrapper

-def promote(mod, fn, verbose=False):
+def promote(mod, fn, handle, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
     wrapper = make_promote_wrapper(orig_fn, maybe_float)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

-def sequence_promote(mod, fn, verbose=False):
+def sequence_promote(mod, fn, handle, verbose=False):
     orig_fn = utils.get_func(mod, fn)
     maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)

     @functools.wraps(orig_fn)
@@ -76,9 +76,9 @@ def sequence_promote(mod, fn, verbose=False):
             # TODO: other mixed-type cases aren't due to amp.
             #       Just pass through?
             return orig_fn(seq, *args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

-def promote_match_arg0(mod, fn, verbose=False):
+def promote_match_arg0(mod, fn, handle, verbose=False):
     if not utils.has_func(mod, fn):
         return
@@ -95,9 +95,9 @@ def promote_match_arg0(mod, fn, verbose=False):
         cast_fn = utils.verbosify(cast_fn, fn, verbose)
         new_args = utils.casted_args(cast_fn, args, kwargs)
         return orig_fn(arg0, *new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

-def err_if_any_half(mod, fn, custom_err_msg=None):
+def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     if not utils.has_func(mod, fn):
         return
@@ -113,9 +113,9 @@ def err_if_any_half(mod, fn, custom_err_msg=None):
                              '{} with fp16 arguments.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

-def err_if_arg0_half(mod, fn, verbose=False):
+def err_if_arg0_half(mod, fn, handle, verbose=False):
     if not utils.has_func(mod, fn):
         return
@@ -130,7 +130,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
         cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
         new_args = utils.casted_args(cast_fn, args, kwargs)
         return orig_fn(arg0, *new_args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)

 # Current RNN approach:
 # - Wrap top-level `RNN` function in thnn backend
@@ -140,7 +140,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
 # - We interpose on the factory function to:
 #   1) Interpose on the actual forward function and put in casts
 #   2) Insert an fp16 `flat_weight` if necessary
-def rnn_cast(backend, fn, verbose=False):
+def rnn_cast(backend, fn, handle, verbose=False):
     orig_rnn = utils.get_func(backend, fn)

     @functools.wraps(orig_rnn)
     def rnn_wrapper(*args, **kwargs):
@@ -203,7 +203,7 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
-    utils.set_func(backend, fn, rnn_wrapper)
+    utils.set_func_save(handle, backend, fn, rnn_wrapper)

 def disable_casts(mod, fn, handle):
     if not utils.has_func(mod, fn):
@@ -214,4 +214,4 @@ def disable_casts(mod, fn, handle):
     def wrapper(*args, **kwargs):
         with handle._disable_casts():
             return orig_fn(*args, **kwargs)
-    utils.set_func(mod, fn, wrapper)
+    utils.set_func_save(handle, mod, fn, wrapper)
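Every installer in wrap.py now follows the same pattern: capture the current function, build the casting or erroring wrapper around it, and install the wrapper through utils.set_func_save so the handle can undo the patch later. A condensed sketch of that shared shape (the install helper below is hypothetical, written only for illustration, not a function in this commit):

# Hypothetical condensed form of the pattern used by promote, err_if_any_half,
# rnn_cast, disable_casts, etc. after this commit:
def install(handle, mod, fn, make_wrapper):
    orig_fn = utils.get_func(mod, fn)              # current (soon-to-be-saved) function
    wrapper = make_wrapper(orig_fn)                # build the amp wrapper around it
    utils.set_func_save(handle, mod, fn, wrapper)  # record orig_fn on the handle, then patch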