Commit 71671993, authored Jun 14, 2018 by Michael Carilli

    Merge branch 'master' of https://github.com/NVIDIA/apex

Parents: b78b5ea5, 43d1ae08
Showing 4 changed files with 37 additions and 4 deletions (+37 -4)
    apex/amp/amp.py      +8   -0
    apex/amp/handle.py   +12  -1
    apex/amp/utils.py    +2   -2
    apex/amp/wrap.py     +15  -1
apex/amp/amp.py

@@ -145,6 +145,14 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # And even more special handling of `backward` for fused gru / lstm
+    # The `backward` method calls Tensor.sum() (blacklist) internally,
+    # and then the resulting grad_input has the wrong type.
+    # TODO: where else is this a problem?
+    for rnn_type in ['GRUFused', 'LSTMFused']:
+        mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
+        wrap.disable_casts(mod, 'backward', handle)
+
     # 6) Place error+print message on banned functions
     if not allow_banned:
         for fn, err_msg in functional_overrides.BANNED_FUNCS:
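Note on the amp.py change: the fused GRU/LSTM `backward` is wrapped with `wrap.disable_casts` because, as the comment says, it calls `Tensor.sum()` internally. Since amp blacklists reductions and promotes their inputs to fp32, the resulting `grad_input` no longer matches the fp16 tensors the fused kernel produced. The snippet below is not apex code, only a small illustration of that promotion; the variable names are invented for the example.

import torch

# Illustration only (not apex code) of the dtype mismatch described above:
# amp's blacklist wrapper casts fp16 inputs to fp32 before calling sum(),
# so the gradient computed inside the fused backward comes back as fp32.
if torch.cuda.is_available():
    h16 = torch.ones(4, device='cuda', dtype=torch.float16)
    plain = h16.sum()             # fp16 when sum() is called directly
    promoted = h16.float().sum()  # what the blacklist wrapper effectively does
    print(plain.dtype, promoted.dtype)   # torch.float16 torch.float32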
apex/amp/handle.py

@@ -11,9 +11,16 @@ class AmpHandle(object):
         self._verbose = verbose
         self._cache = dict()
         self._default_scaler = LossScaler()
+        self._is_active = True
 
     def is_active(self):
-        return True
+        return self._is_active
+
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        self._is_active = False
+        yield
+        self._is_active = True
 
     def wrap_optimizer(self, optimizer, num_loss=1):
         self._default_scaler = None

@@ -76,6 +83,10 @@ class NoOpHandle(object):
     def is_active(self):
         return False
 
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        yield
+
     def wrap_optimizer(self, optimizer, num_loss=1):
         return OptimWrapper(optimizer, self, num_loss)
 
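Note on the handle.py change: `AmpHandle` now carries an `_is_active` flag and a `_disable_casts` context manager that turns the flag off for the duration of a block, while `NoOpHandle` gets a matching no-op version. A minimal stand-alone sketch of the same pattern (`ToyHandle` is a hypothetical name, not part of apex):

import contextlib

class ToyHandle(object):
    # Hypothetical stand-in mirroring the AmpHandle pattern above.
    def __init__(self):
        self._is_active = True

    def is_active(self):
        return self._is_active

    @contextlib.contextmanager
    def _disable_casts(self):
        # Turn casting off for the duration of the with-block, then back on.
        self._is_active = False
        yield
        self._is_active = True

handle = ToyHandle()
print(handle.is_active())        # True
with handle._disable_casts():
    print(handle.is_active())    # False: wrapped ops skip casting in here
print(handle.is_active())        # True again

Note that, as written in the diff, the flag is restored to True unconditionally and without a try/finally, so an exception raised inside the block would leave casting disabled.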
apex/amp/utils.py

@@ -49,7 +49,7 @@ def maybe_half(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_half(y) for y in x])
 
-    if type_string(x) == 'HalfTensor':
+    if not x.is_cuda or type_string(x) == 'HalfTensor':
         return x
     else:
         if verbose:

@@ -60,7 +60,7 @@ def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
 
-    if type_string(x) == 'FloatTensor':
+    if not x.is_cuda or type_string(x) == 'FloatTensor':
         return x
     else:
         if verbose:
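Note on the utils.py change: `maybe_half` and `maybe_float` now pass CPU tensors through untouched (`not x.is_cuda or ...`), so only CUDA tensors are ever cast. A simplified, hypothetical restatement of the predicate follows; it checks `dtype` directly rather than apex's `type_string` helper.

import torch

def maybe_half_simplified(x):
    # Hypothetical, simplified version of the guard above: leave CPU tensors
    # and tensors that are already fp16 alone; only CUDA fp32 tensors get cast.
    if not x.is_cuda or x.dtype == torch.float16:
        return x
    return x.half()

cpu = torch.ones(2)
print(maybe_half_simplified(cpu).dtype)        # torch.float32 (unchanged)
if torch.cuda.is_available():
    gpu = torch.ones(2, device='cuda')
    print(maybe_half_simplified(gpu).dtype)    # torch.float16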
apex/amp/wrap.py

@@ -9,6 +9,9 @@ def make_cast_wrapper(orig_fn, cast_fn, handle,
                       try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
+        if not handle.is_active():
+            return orig_fn(*args, **kwargs)
+
         if try_caching and handle.has_cache:
             args = list(args)
             for i in range(len(args)):

@@ -70,7 +73,7 @@ def sequence_promote(mod, fn, verbose=False):
                                          seq, {})
             return orig_fn(cast_seq, *args, **kwargs)
         else:
-            # TODO: other mixed-type cases aren't due to autohalf.
+            # TODO: other mixed-type cases aren't due to amp.
             # Just pass through?
             return orig_fn(seq, *args, **kwargs)
     utils.set_func(mod, fn, wrapper)

@@ -201,3 +204,14 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
     utils.set_func(backend, fn, rnn_wrapper)
+
+def disable_casts(mod, fn, handle):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    @functools.wraps(orig_fn)
+    def wrapper(*args, **kwargs):
+        with handle._disable_casts():
+            return orig_fn(*args, **kwargs)
+    utils.set_func(mod, fn, wrapper)
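Note on the wrap.py change: `make_cast_wrapper` now returns early when the handle reports it is inactive, and the new module-level `disable_casts` patches a function so that everything it calls runs inside `handle._disable_casts()`. The sketch below shows how the two pieces compose; it is not apex code, and names such as `cast_wrapper`, `no_cast_wrapper`, and `fused_backward` are invented for the example.

import contextlib
import functools

_casts_active = [True]          # plays the role of AmpHandle._is_active

@contextlib.contextmanager
def _disable_casts():
    _casts_active[0] = False
    yield
    _casts_active[0] = True

def cast_wrapper(fn):
    # Mirrors make_cast_wrapper's new early-out: do nothing when inactive.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not _casts_active[0]:
            return fn(*args, **kwargs)
        print("would cast args of", fn.__name__, "here")
        return fn(*args, **kwargs)
    return wrapper

def no_cast_wrapper(fn):
    # Mirrors disable_casts: suppress casting for everything fn calls.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with _disable_casts():
            return fn(*args, **kwargs)
    return wrapper

@cast_wrapper
def inner(x):
    return x

@no_cast_wrapper
def fused_backward(x):
    return inner(x)   # inner's cast wrapper falls through while casts are off

inner(1)              # prints the "would cast" message
fused_backward(1)     # prints nothing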