Commit fb419005 authored by Carl Case

Hard ban on fp16 BCELoss

parent 6d30e1ff
@@ -54,7 +54,7 @@ def register_promote_function(module, name):
     _USER_PROMOTE_REGISTRY.add((module, name))
 
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, enable_caching=True, verbose=False):
+def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE
 
     if not enabled:
@@ -145,5 +145,10 @@ def init(enabled=True, enable_caching=True, verbose=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # 6) Place error+print message on banned functions
+    if not allow_banned:
+        for fn, err_msg in functional_overrides.BANNED_FUNCS:
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+
     _DECORATOR_HANDLE = handle
     return handle
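
A minimal usage sketch of the new behavior (my illustration, not part of the commit): with the default allow_banned=False, calling a banned function such as F.binary_cross_entropy on fp16 tensors after amp.init() should raise NotImplementedError with the message defined in the next file; passing allow_banned=True opts out. The import path and tensor shapes here are assumptions.

import torch
import torch.nn.functional as F
from apex import amp   # assumed import path for this version of apex

amp.init()  # default allow_banned=False: banned functions error on fp16 inputs

probs = torch.sigmoid(torch.randn(8, 1).cuda().half())
target = torch.rand(8, 1).cuda().half()

try:
    F.binary_cross_entropy(probs, target)
except NotImplementedError as err:
    print(err)  # the BANNED_FUNCS message defined in functional_overrides

# Escape hatch, only if you are sure the fp16 inputs are safe:
# amp.init(allow_banned=True)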
@@ -42,7 +42,6 @@ FP32_FUNCS = [
 
     # Loss functions
     # TODO: which of these can be fp16?
-    'binary_cross_entropy',
     'poisson_nll_loss',
     'cosine_embedding_loss',
     'cross_entropy',
@@ -60,3 +59,15 @@ FP32_FUNCS = [
     'soft_margin_loss',
     'triplet_margin_loss'
 ]
+
+BANNED_FUNCS = [
+    ('binary_cross_entropy',
+     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
+      "It requires that the output of the previous function be already a FloatTensor. \n\n"
+      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
+      "    torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
+      "that is compatible with amp.\nAnother option is to add\n"
+      "    amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
+      "If you _really_ know what you are doing, you can disable this warning by passing "
+      "allow_banned=True to `amp.init()`."))
+]
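
The error text above recommends two workarounds. A short sketch of the first one (the model and shapes are made up for illustration): drop the trailing Sigmoid and hand raw logits to torch.nn.BCEWithLogitsLoss, which fuses Sigmoid+BCE and stays off the banned fp16 path.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 1))   # no nn.Sigmoid() at the end
criterion = nn.BCEWithLogitsLoss()        # fused Sigmoid + BCE

logits = model(torch.randn(8, 16))
target = torch.rand(8, 1)                 # targets in [0, 1]
loss = criterion(logits, target)

# Alternative mentioned in the message: keep Sigmoid+BCELoss but force the
# sigmoid output to fp32 before calling amp.init():
# amp.register_float_function(torch, 'sigmoid')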
@@ -94,7 +94,7 @@ def promote_match_arg0(mod, fn, verbose=False):
         return orig_fn(arg0, *new_args, **kwargs)
     utils.set_func(mod, fn, wrapper)
 
-def err_if_any_half(mod, fn):
+def err_if_any_half(mod, fn, custom_err_msg=None):
     if not utils.has_func(mod, fn):
         return
@@ -103,6 +103,9 @@ def err_if_any_half(mod, fn):
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
         if 'HalfTensor' in types:
-            raise NotImplementedError('Cannot call in-place function ' +
-                                      '{} with fp16 arguments.'.format(fn))
+            if custom_err_msg:
+                raise NotImplementedError(custom_err_msg)
+            else:
+                raise NotImplementedError('Cannot call in-place function ' +
+                                          '{} with fp16 arguments.'.format(fn))
         else:
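
As a rough illustration of the wrapping pattern above (a simplified stand-in, since the apex utils helpers are not shown in this commit): the wrapper replaces a module-level function and raises, with an optional custom message, whenever any tensor argument is fp16.

import functools
import torch

def err_if_any_half_sketch(mod, fn, custom_err_msg=None):
    # Simplified stand-in for wrap.err_if_any_half; apex uses utils.get_func/
    # set_func/collect_fp_tensor_types instead of the inline checks below.
    orig_fn = getattr(mod, fn)

    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        tensors = [a for a in list(args) + list(kwargs.values())
                   if torch.is_tensor(a)]
        if any(t.dtype == torch.float16 for t in tensors):
            if custom_err_msg:
                raise NotImplementedError(custom_err_msg)
            else:
                raise NotImplementedError('Cannot call in-place function ' +
                                          '{} with fp16 arguments.'.format(fn))
        return orig_fn(*args, **kwargs)

    setattr(mod, fn, wrapper)

# e.g. err_if_any_half_sketch(torch.nn.functional, 'binary_cross_entropy',
#                             custom_err_msg="use BCEWithLogitsLoss instead")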