Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
fb419005
"docs/vscode:/vscode.git/clone" did not exist on "03df281275ad3fcb732a41ab1638c2e89afddb25"
Commit
fb419005
authored
May 30, 2018
by
Carl Case
Browse files
Hard ban on fp16 BCELoss
parent
6d30e1ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
5 deletions
+24
-5
apex/amp/amp.py
apex/amp/amp.py
+6
-1
apex/amp/lists/functional_overrides.py
apex/amp/lists/functional_overrides.py
+12
-1
apex/amp/wrap.py
apex/amp/wrap.py
+6
-3
No files found.
apex/amp/amp.py
View file @
fb419005
...
...
@@ -54,7 +54,7 @@ def register_promote_function(module, name):
_USER_PROMOTE_REGISTRY
.
add
((
module
,
name
))
# Top-level function to insert _all_ the hooks.
def
init
(
enabled
=
True
,
enable_caching
=
True
,
verbose
=
False
):
def
init
(
enabled
=
True
,
enable_caching
=
True
,
verbose
=
False
,
allow_banned
=
False
):
global
_DECORATOR_HANDLE
if
not
enabled
:
...
...
@@ -145,5 +145,10 @@ def init(enabled=True, enable_caching=True, verbose=False):
# 5.5) Extra-special handling of RNN backend
wrap
.
rnn_cast
(
torch
.
nn
.
backends
.
thnn
.
backend
,
'RNN'
,
verbose
)
# 6) Place error+print message on banned functions
if
not
allow_banned
:
for
fn
,
err_msg
in
functional_overrides
.
BANNED_FUNCS
:
wrap
.
err_if_any_half
(
functional_overrides
.
MODULE
,
fn
,
err_msg
)
_DECORATOR_HANDLE
=
handle
return
handle
apex/amp/lists/functional_overrides.py
View file @
fb419005
...
...
@@ -42,7 +42,6 @@ FP32_FUNCS = [
# Loss functions
# TODO: which of these can be fp16?
'binary_cross_entropy'
,
'poisson_nll_loss'
,
'cosine_embedding_loss'
,
'cross_entropy'
,
...
...
@@ -60,3 +59,15 @@ FP32_FUNCS = [
'soft_margin_loss'
,
'triplet_margin_loss'
]
BANNED_FUNCS
=
[
(
'binary_cross_entropy'
,
(
"
\n
amp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
"It requires that the output of the previous function be already a FloatTensor.
\n\n
"
"Most models have a Sigmoid right before BCELoss. In that case, you can use
\n
"
" torch.nn.BCEWithLogitsLoss
\n
to combine Sigmoid+BCELoss into a single layer "
"that is compatible with amp.
\n
Another option is to add
\n
"
" amp.register_float_function(torch, 'sigmoid')
\n
before calling `amp.init()`.
\n
"
"If you _really_ know what you are doing, you can disable this warning by passing "
"allow_banned=True to `amp.init()`."
))
]
apex/amp/wrap.py
View file @
fb419005
...
...
@@ -94,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):
return
orig_fn
(
arg0
,
*
new_args
,
**
kwargs
)
utils
.
set_func
(
mod
,
fn
,
wrapper
)
def
err_if_any_half
(
mod
,
fn
):
def
err_if_any_half
(
mod
,
fn
,
custom_err_msg
=
None
):
if
not
utils
.
has_func
(
mod
,
fn
):
return
...
...
@@ -103,8 +103,11 @@ def err_if_any_half(mod, fn):
def
wrapper
(
*
args
,
**
kwargs
):
types
=
utils
.
collect_fp_tensor_types
(
args
,
kwargs
)
if
'HalfTensor'
in
types
:
raise
NotImplementedError
(
'Cannot call in-place function '
+
'{} with fp16 arguments.'
.
format
(
fn
))
if
custom_err_msg
:
raise
NotImplementedError
(
custom_err_msg
)
else
:
raise
NotImplementedError
(
'Cannot call in-place function '
+
'{} with fp16 arguments.'
.
format
(
fn
))
else
:
return
orig_fn
(
*
args
,
**
kwargs
)
utils
.
set_func
(
mod
,
fn
,
wrapper
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment