"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d38600fee193e8b9f97f3ec5ea02d3fd355cf6a0"
Unverified commit 43d1ae08, authored by Carl Case, committed by GitHub

Merge pull request #11 from NVIDIA/amp_lstm_backward

Handle the use of .sum() in fused LSTM/GRU backward
parents 227a9a2d 32fbbe48
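Background for the change: amp wraps blacklist functions such as `sum` so they execute in FP32. When a fused GRU/LSTM cell calls `.sum()` inside its `backward` (e.g. to reduce a bias gradient), that wrapper fires and the result comes back as FP32 even though the surrounding gradients are FP16, so `grad_input` ends up with the wrong type. A minimal sketch of the dtype mismatch, using a stand-in wrapper rather than the actual amp machinery:

```python
# Sketch only: emulate amp's blacklist handling of sum with a stand-in
# wrapper, to show how a value computed inside backward ends up FP32.
import torch

def blacklist_wrapper(fn):
    # Promote half-precision inputs to float32 before calling fn,
    # mimicking what amp does for blacklist ops.
    def wrapper(x, *args, **kwargs):
        if x.is_floating_point() and x.dtype == torch.float16:
            x = x.float()
        return fn(x, *args, **kwargs)
    return wrapper

cast_sum = blacklist_wrapper(torch.sum)

grad_output = torch.randn(4, 8, dtype=torch.float16)
# Inside a fused GRU/LSTM backward, a gradient is reduced with sum():
grad_bias = cast_sum(grad_output, dim=0)
print(grad_bias.dtype)  # torch.float32 -- mismatches the FP16 grads around it
```

The diffs below fix this by skipping amp's casting entirely while those `backward` methods run.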
@@ -145,6 +145,14 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
     # 5.5) Extra-special handling of RNN backend
     wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
 
+    # And even more special handling of `backward` for fused gru / lstm
+    # The `backward` method calls Tensor.sum() (blacklist) internally,
+    # and then the resulting grad_input has the wrong type.
+    # TODO: where else is this a problem?
+    for rnn_type in ['GRUFused', 'LSTMFused']:
+        mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
+        wrap.disable_casts(mod, 'backward', handle)
+
     # 6) Place error+print message on banned functions
     if not allow_banned:
         for fn, err_msg in functional_overrides.BANNED_FUNCS:
@@ -11,9 +11,16 @@ class AmpHandle(object):
         self._verbose = verbose
         self._cache = dict()
         self._default_scaler = LossScaler()
+        self._is_active = True
 
     def is_active(self):
-        return True
+        return self._is_active
+
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        self._is_active = False
+        yield
+        self._is_active = True
 
     def wrap_optimizer(self, optimizer, num_loss=1):
         self._default_scaler = None
@@ -76,6 +83,10 @@ class NoOpHandle(object):
     def is_active(self):
         return False
 
+    @contextlib.contextmanager
+    def _disable_casts(self):
+        yield
+
     def wrap_optimizer(self, optimizer, num_loss=1):
         return OptimWrapper(optimizer, self, num_loss)
@@ -9,6 +9,9 @@ def make_cast_wrapper(orig_fn, cast_fn, handle,
                       try_caching=False):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
+        if not handle.is_active():
+            return orig_fn(*args, **kwargs)
+
         if try_caching and handle.has_cache:
             args = list(args)
             for i in range(len(args)):
@@ -201,3 +204,14 @@ def rnn_cast(backend, fn, verbose=False):
             return forward(*new_args, **fkwargs)
         return fwd_wrapper
     utils.set_func(backend, fn, rnn_wrapper)
+
+def disable_casts(mod, fn, handle):
+    if not utils.has_func(mod, fn):
+        return
+
+    orig_fn = utils.get_func(mod, fn)
+    @functools.wraps(orig_fn)
+    def wrapper(*args, **kwargs):
+        with handle._disable_casts():
+            return orig_fn(*args, **kwargs)
+    utils.set_func(mod, fn, wrapper)
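Putting the pieces together: `make_cast_wrapper` now returns early whenever the handle reports it is inactive, and the new `disable_casts` patches a module's `backward` so the handle is inactive for exactly the duration of that call. A self-contained toy version of the pattern (simplified stand-ins, not the actual amp modules or PyTorch RNN backends):

```python
import contextlib
import functools

class Handle:
    def __init__(self):
        self._is_active = True
    def is_active(self):
        return self._is_active
    @contextlib.contextmanager
    def _disable_casts(self):
        self._is_active = False
        yield
        self._is_active = True

def make_cast_wrapper(orig_fn, handle):
    # "Cast" list elements to float, unless the handle says casting is disabled.
    @functools.wraps(orig_fn)
    def wrapper(xs):
        if not handle.is_active():          # the early-out added in this PR
            return orig_fn(xs)
        return orig_fn([float(x) for x in xs])
    return wrapper

def disable_casts(obj, name, handle):
    # Patch obj.<name> so the handle is inactive while it runs.
    orig_fn = getattr(obj, name)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        with handle._disable_casts():
            return orig_fn(*args, **kwargs)
    setattr(obj, name, wrapper)

handle = Handle()
wrapped_sum = make_cast_wrapper(sum, handle)

class FusedCell:                            # toy stand-in for GRUFused/LSTMFused
    @staticmethod
    def backward(xs):
        return wrapped_sum(xs)

print(type(wrapped_sum([1, 2, 3])))         # float: cast applied at top level
disable_casts(FusedCell, 'backward', handle)
print(type(FusedCell.backward([1, 2, 3])))  # int: cast skipped inside backward
```

Inside the patched `backward`, the cast wrapper becomes a pass-through, which is how the fused GRU/LSTM `.sum()` call avoids producing an FP32 gradient.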