Commit ebaa5a15 authored by Carl Case

experimental: ability to deactivate amp with handle

parent 437bcf22
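
This commit threads the `handle` returned by `amp.init()` through every wrapping call so that each monkey-patched function can later be restored. A minimal usage sketch of the new deactivation flow, assuming the usual `from apex import amp` import path (the training loop is elided):

```python
from apex import amp

# amp.init() monkey-patches the listed torch functions and returns a handle;
# with this change the handle also records every (module, name, original
# function) triple it replaces.
handle = amp.init(enabled=True)

# ... build the model / optimizer and train as usual: patched functions
# cast their inputs to fp16 or fp32 according to the lists ...

# Experimental (this commit): put every patched function back, restoring
# the original, uncasted versions. `_deactivate()` is private and may change.
handle._deactivate()
```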
@@ -73,7 +73,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
# 0.5) Force-promote for user-annotated functions
for mod, fn in _USER_PROMOTE_REGISTRY:
- wrap.promote(mod, fn, verbose)
+ wrap.promote(mod, fn, handle, verbose)
_USER_PROMOTE_REGISTRY.clear()
# 1) Force-{fp16, fp32} on white- / black-list functions
@@ -107,7 +107,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
promote_table):
for fn in getattr(promote_mod, list_name):
- promote_fn(promote_mod.MODULE, fn, verbose)
+ promote_fn(promote_mod.MODULE, fn, handle, verbose)
# 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
if compat.tensor_is_float_tensor():
@@ -115,27 +115,27 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
torch.cuda.HalfTensor],
promote_table):
for fn in getattr(tensor_overrides, list_name):
- promote_fn(cls, fn, verbose)
+ promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
- wrap.err_if_any_half(torch_overrides.MODULE, fn)
+ wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
- wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, verbose)
+ wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
- wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, verbose)
+ wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
# 4) For other in-place methods, match the type of self tensor
for fn in utils.as_inplace(itertools.chain(
tensor_overrides.FP16_FUNCS,
tensor_overrides.CASTS)):
- wrap.promote_match_arg0(tensor_overrides.MODULE, fn, verbose)
+ wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
- wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, verbose)
- wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, verbose)
+ wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+ wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
# 5) Special handling to whitelist RNN cell backend impls.
for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
@@ -143,7 +143,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
handle, try_caching=True, verbose=verbose)
# 5.5) Extra-special handling of RNN backend
- wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
+ wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
# And even more special handling of `backward` for fused gru / lstm
# The `backward` method calls Tensor.sum() (blacklist) internally,
@@ -156,7 +156,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
# 6) Place error+print message on banned functions
if not allow_banned:
for fn, err_msg in functional_overrides.BANNED_FUNCS:
- wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
+ wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
_DECORATOR_HANDLE = handle
return handle
@@ -2,6 +2,7 @@ import contextlib
import logging
import warnings
+ from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler
@@ -12,6 +13,7 @@ class AmpHandle(object):
self._cache = dict()
self._default_scaler = LossScaler()
self._is_active = True
+ self._all_wrappers = []
def is_active(self):
return self._is_active
@@ -63,6 +65,15 @@ class AmpHandle(object):
def _clear_cache(self):
self._cache.clear()
+ # Experimental support for saving / restoring uncasted versions of functions
+ def _save_func(self, mod, fn, func):
+     self._all_wrappers.append((mod, fn, func))
+ def _deactivate(self):
+     for mod, fn, func in self._all_wrappers:
+         utils.set_func(mod, fn, func)
+     self._all_wrappers = []
@property
def has_cache(self):
return self._enable_caching
@@ -126,6 +126,11 @@ def set_func(mod, fn, new_fn):
else:
setattr(mod, fn, new_fn)
+ def set_func_save(handle, mod, fn, new_fn):
+     cur_fn = get_func(mod, fn)
+     handle._save_func(mod, fn, cur_fn)
+     set_func(mod, fn, new_fn)
# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
# so the fp16 weights need to be derived from the input weights
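
Taken together, `set_func_save` and `AmpHandle._save_func` / `_deactivate` form a small save-then-patch-then-restore registry. A standalone sketch of the same pattern, with illustrative names that are not part of apex:

```python
import types

# Toy "module" standing in for torch or a tensor-overrides module.
fake_mod = types.SimpleNamespace(add=lambda a, b: a + b)

# Registry of (module, attribute name, original function) tuples,
# analogous to AmpHandle._all_wrappers.
_saved = []

def set_func_save(mod, fn, new_fn):
    # Remember whatever is currently installed before overwriting it.
    _saved.append((mod, fn, getattr(mod, fn)))
    setattr(mod, fn, new_fn)

def deactivate():
    # Put every original function back, mirroring AmpHandle._deactivate().
    for mod, fn, orig in _saved:
        setattr(mod, fn, orig)
    _saved.clear()

set_func_save(fake_mod, 'add', lambda a, b: 2 * (a + b))  # install a wrapper
assert fake_mod.add(1, 2) == 6
deactivate()                                              # undo the patch
assert fake_mod.add(1, 2) == 3
```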
@@ -34,7 +34,7 @@ def cached_cast(mod, fn, cast_fn, handle,
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@@ -54,13 +54,13 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
.format(types))
return wrapper
- def promote(mod, fn, verbose=False):
+ def promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
wrapper = make_promote_wrapper(orig_fn, maybe_float)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
- def sequence_promote(mod, fn, verbose=False):
+ def sequence_promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@functools.wraps(orig_fn)
@@ -76,9 +76,9 @@ def sequence_promote(mod, fn, verbose=False):
# TODO: other mixed-type cases aren't due to amp.
# Just pass through?
return orig_fn(seq, *args, **kwargs)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
- def promote_match_arg0(mod, fn, verbose=False):
+ def promote_match_arg0(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
@@ -95,9 +95,9 @@ def promote_match_arg0(mod, fn, verbose=False):
cast_fn = utils.verbosify(cast_fn, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
- def err_if_any_half(mod, fn, custom_err_msg=None):
+ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
if not utils.has_func(mod, fn):
return
@@ -113,9 +113,9 @@ def err_if_any_half(mod, fn, custom_err_msg=None):
'{} with fp16 arguments.'.format(fn))
else:
return orig_fn(*args, **kwargs)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
- def err_if_arg0_half(mod, fn, verbose=False):
+ def err_if_arg0_half(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
@@ -130,7 +130,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
@@ -140,7 +140,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
# - We interpose on the factory function to:
# 1) Interpose on the actual forward function and put in casts
# 2) Insert an fp16 `flat_weight` if necessary
- def rnn_cast(backend, fn, verbose=False):
+ def rnn_cast(backend, fn, handle, verbose=False):
orig_rnn = utils.get_func(backend, fn)
@functools.wraps(orig_rnn)
def rnn_wrapper(*args, **kwargs):
@@ -203,7 +203,7 @@ def rnn_cast(backend, fn, verbose=False):
return forward(*new_args, **fkwargs)
return fwd_wrapper
- utils.set_func(backend, fn, rnn_wrapper)
+ utils.set_func_save(handle, backend, fn, rnn_wrapper)
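
The RNN path needs `set_func_save` on the factory rather than on a plain forward function: the thnn backend entry point returns the callable that actually runs the RNN, so the wrapper interposes twice. A rough two-level sketch of that idea, independent of the backend and flat_weight details (all names below are illustrative):

```python
import functools

def make_rnn(mode):
    # Factory: returns the function that will actually execute the RNN.
    def forward(inputs, weights):
        return sum(inputs) + sum(weights)  # stand-in for the real kernel
    return forward

def wrap_factory(factory, cast):
    # Level 1: interpose on the factory itself.
    @functools.wraps(factory)
    def factory_wrapper(*args, **kwargs):
        forward = factory(*args, **kwargs)

        # Level 2: interpose on the forward function the factory returned,
        # casting its arguments before the real computation runs.
        @functools.wraps(forward)
        def forward_wrapper(inputs, weights):
            return forward([cast(x) for x in inputs], [cast(w) for w in weights])
        return forward_wrapper
    return factory_wrapper

rnn = wrap_factory(make_rnn, cast=float)('RNN_TANH')
print(rnn([1, 2], [3]))  # 6.0 -- the arguments went through the cast first
```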
def disable_casts(mod, fn, handle):
if not utils.has_func(mod, fn):
@@ -214,4 +214,4 @@ def disable_casts(mod, fn, handle):
def wrapper(*args, **kwargs):
with handle._disable_casts():
return orig_fn(*args, **kwargs)
- utils.set_func(mod, fn, wrapper)
+ utils.set_func_save(handle, mod, fn, wrapper)
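
`disable_casts` (unchanged here except for how the wrapper is installed) runs the wrapped function inside `handle._disable_casts()`, a context manager that temporarily switches the cast machinery off. A minimal sketch of that pattern using a plain flag; the `Handle` class below is illustrative, not the apex implementation:

```python
import contextlib
import functools

class Handle:
    def __init__(self):
        self._is_active = True

    @contextlib.contextmanager
    def _disable_casts(self):
        # Temporarily mark amp as inactive; cast wrappers would check this
        # flag and fall through to the original function while it is False.
        self._is_active = False
        try:
            yield
        finally:
            self._is_active = True

def disable_casts(mod, fn, handle):
    orig_fn = getattr(mod, fn)

    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        with handle._disable_casts():
            return orig_fn(*args, **kwargs)

    setattr(mod, fn, wrapper)
```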