OpenDAS / apex · Commit 43d1ae08

Unverified commit 43d1ae08, authored Jun 12, 2018 by Carl Case, committed by GitHub on Jun 12, 2018

Merge pull request #11 from NVIDIA/amp_lstm_backward

Handle the use of .sum() in fused LSTM/GRU backward

Parents: 227a9a2d 32fbbe48
Showing 3 changed files with 34 additions and 1 deletion (+34 -1):

    apex/amp/amp.py     +8  -0
    apex/amp/handle.py  +12 -1
    apex/amp/wrap.py    +14 -0
apex/amp/amp.py
@@ -145,6 +145,14 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # And even more special handling of `backward` for fused gru / lstm
+    # The `backward` method calls Tensor.sum() (blacklist) internally,
+    # and then the resulting grad_input has the wrong type.
+    # TODO: where else is this a problem?
+    for rnn_type in ['GRUFused', 'LSTMFused']:
+        mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
+        wrap.disable_casts(mod, 'backward', handle)
+
     # 6) Place error+print message on banned functions
     if not allow_banned:
         for fn, err_msg in functional_overrides.BANNED_FUNCS:
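As a hedged aside (not part of the commit): in amp, Tensor.sum() sits on the fp32 "blacklist", so a cast wrapper left active inside the fused LSTM/GRU backward would promote the summed gradient to fp32 and hand back a grad_input with the wrong dtype. A minimal, self-contained illustration of that dtype mismatch, with made-up shapes and names:

    # Illustrative only -- shows how promoting .sum() to fp32 (amp's blacklist
    # behavior) yields a gradient dtype the fused backward does not expect.
    import torch

    grad_hy = torch.randn(4, 8, dtype=torch.float16)   # incoming fp16 gradient

    grad_bias_cast = grad_hy.float().sum(0)   # what an active blacklist cast would produce
    grad_bias_raw = grad_hy.sum(0)            # what the fused backward expects

    print(grad_bias_cast.dtype)   # torch.float32 -- wrong dtype for grad_input
    print(grad_bias_raw.dtype)    # torch.float16 -- preserved when casts are disabled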
apex/amp/handle.py
@@ -11,9 +11,16 @@ class AmpHandle(object):
         self._verbose = verbose
         self._cache = dict()
         self._default_scaler = LossScaler()
+        self._is_active = True
 
     def is_active(self):
-        return True
+        return self._is_active
+
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        self._is_active = False
+        yield
+        self._is_active = True
 
     def wrap_optimizer(self, optimizer, num_loss=1):
         self._default_scaler = None
@@ -76,6 +83,10 @@ class NoOpHandle(object):
     def is_active(self):
         return False
 
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        yield
+
     def wrap_optimizer(self, optimizer, num_loss=1):
         return OptimWrapper(optimizer, self, num_loss)
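To make the new handle behavior concrete, here is a small self-contained mock (a sketch; DemoHandle is a stand-in, not apex's actual AmpHandle) showing how the _disable_casts context manager toggles is_active() for the duration of a call:

    # Standalone sketch mirroring the AmpHandle changes above.
    import contextlib

    class DemoHandle:
        def __init__(self):
            self._is_active = True

        def is_active(self):
            return self._is_active

        @contextlib.contextmanager
        def _disable_casts(self):
            self._is_active = False
            yield
            self._is_active = True

    h = DemoHandle()
    print(h.is_active())        # True
    with h._disable_casts():
        print(h.is_active())    # False -- wrapped functions skip casting here
    print(h.is_active())        # True again once the context exits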
apex/amp/wrap.py
@@ -9,6 +9,9 @@ def make_cast_wrapper(orig_fn, cast_fn, handle,
                       try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
+        if not handle.is_active():
+            return orig_fn(*args, **kwargs)
+
         if try_caching and handle.has_cache:
             args = list(args)
             for i in range(len(args)):
@@ -201,3 +204,14 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
     utils.set_func(backend, fn, rnn_wrapper)
+
+def disable_casts(mod, fn, handle):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    @functools.wraps(orig_fn)
+    def wrapper(*args, **kwargs):
+        with handle._disable_casts():
+            return orig_fn(*args, **kwargs)
+    utils.set_func(mod, fn, wrapper)
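Putting the pieces together, the sketch below (toy names throughout; apex's utils.has_func/get_func/set_func are approximated here with hasattr/getattr/setattr) shows how the disable_casts pattern re-binds a module's backward so the handle reports inactive while the original function runs:

    # Toy end-to-end sketch of the disable_casts pattern; ToyBackend and _Handle
    # are illustrative stand-ins for the fused RNN class and AmpHandle.
    import contextlib
    import functools

    class _Handle:
        def __init__(self):
            self.active = True

        @contextlib.contextmanager
        def _disable_casts(self):
            self.active = False
            yield
            self.active = True

    class ToyBackend:
        @staticmethod
        def backward(grad):
            return grad * 2

    def disable_casts(mod, fn, handle):
        # Re-bind mod.<fn> to a wrapper that turns casting off around the call.
        if not hasattr(mod, fn):
            return
        orig_fn = getattr(mod, fn)

        @functools.wraps(orig_fn)
        def wrapper(*args, **kwargs):
            with handle._disable_casts():
                return orig_fn(*args, **kwargs)
        setattr(mod, fn, wrapper)

    h = _Handle()
    disable_casts(ToyBackend, 'backward', h)
    print(ToyBackend.backward(3))   # 6, computed while h.active was False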