Commit 3e1a1c09 authored by Carl Case, committed by mcarilli

amp support for Aten RNNs (#41)

* WIP: update to support new RNN backend code

* small refactor

* add test for rnn w/ packed sequences
parent 1579b9e3
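
For orientation, here is a minimal sketch of how a training script would drive amp with an ATen-backed RNN at this point in the project, using the `amp.init()` handle API. The model, optimizer, and shapes are illustrative only, not taken from this commit; `handle.scale_loss` is the handle-based loss-scaling context manager from that era of apex.

```python
import torch
from torch import nn
from apex import amp

# Illustrative driver only (names and sizes are made up, not from this commit).
handle = amp.init(enabled=True, verbose=False)

model = nn.LSTM(input_size=32, hidden_size=32, num_layers=2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(16, 8, 32, device='cuda')  # (seq_len, batch, input_size)
output, _ = model(x)                        # amp casts the ATen RNN call to fp16
loss = output.float().sum()

with handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```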
-from . import compat, utils, wrap
+from . import compat, rnn_compat, utils, wrap
 from .handle import AmpHandle, NoOpHandle
 from .lists import functional_overrides, torch_overrides, tensor_overrides
...
@@ -137,21 +137,18 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
         wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
         wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

-    # 5) Special handling to whitelist RNN cell backend impls.
-    for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
-        wrap.cached_cast(torch.nn.backends.thnn.backend, fn, utils.maybe_half,
-                         handle, try_caching=True, verbose=verbose)
-
-    # 5.5) Extra-special handling of RNN backend
-    wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
-
-    # And even more special handling of `backward` for fused gru / lstm
-    # The `backward` method calls Tensor.sum() (blacklist) internally,
-    # and then the resulting grad_input has the wrong type.
-    # TODO: where else is this a problem?
-    for rnn_type in ['GRUFused', 'LSTMFused']:
-        mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
-        wrap.disable_casts(mod, 'backward', handle)
+    # 5) RNNs + RNN cells are whitelisted specially
+    if rnn_compat.has_old_rnns():
+        wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
+    if not rnn_compat.has_old_rnns():
+        # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
+        torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
+        # Wrap all the rnns
+        for x in rnn_compat.RNN_NAMES:
+            wrap.new_rnn_cast(x.upper(), handle, verbose)
+
+    # Wrap all the RNN cells
+    rnn_compat.whitelist_rnn_cells(handle, verbose)

     # 6) Place error+print message on banned functions.
     # Or, if allow_banned, then cast to FP32.
...
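
All of the branches above rely on the same underlying move: look a function up on some container (a module, a THNN backend, or now a dict), wrap it so floating-point tensor arguments get cast to half, and write the wrapper back. A minimal, standalone sketch of that interposition pattern, without amp's caching, handle, or verbosity machinery (the names here are stand-ins, not apex API):

```python
import functools
import torch

def _maybe_half(x):
    # Cast only floating-point tensors; pass everything else through untouched.
    return x.half() if torch.is_tensor(x) and x.is_floating_point() else x

def install_cast_wrapper(container, fn_name):
    """Replace container.fn_name with a version that casts tensor args to fp16."""
    orig_fn = getattr(container, fn_name)

    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        new_args = [_maybe_half(a) for a in args]
        new_kwargs = {k: _maybe_half(v) for k, v in kwargs.items()}
        return orig_fn(*new_args, **new_kwargs)

    setattr(container, fn_name, wrapper)
```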
from . import utils, wrap

import torch

_VF = torch._C._VariableFunctions
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']

def _gen_VF_wrapper(name):
    def wrapper(*args, **kwargs):
        return getattr(_VF, name)(*args, **kwargs)
    return wrapper

# Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version.
class VariableFunctionsShim(object):
    def __init__(self):
        for name in RNN_NAMES:
            setattr(self, name + '_cell', _gen_VF_wrapper(name + '_cell'))

def has_old_rnns():
    try:
        torch.nn.backends.thnn.backend.LSTMCell
        return True
    except:
        return False

def whitelist_rnn_cells(handle, verbose):
    # Different module + function names in old/new RNN cases
    if has_old_rnns():
        fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
        mod = torch.nn.backends.thnn.backend
    else:
        fn_names = [x + '_cell' for x in RNN_NAMES]
        mod = torch.nn.modules.rnn._VF
        assert isinstance(mod, VariableFunctionsShim)

    # Insert casts on cell functions
    for fn in fn_names:
        wrap.cached_cast(mod, fn, utils.maybe_half, handle,
                         try_caching=True, verbose=verbose)

    if has_old_rnns():
        # Special handling of `backward` for fused gru / lstm:
        # The `backward` method calls Tensor.sum() (blacklist) internally,
        # and then the resulting grad_input has the wrong type.
        # TODO: where else is this a problem?
        for rnn_type in ['GRUFused', 'LSTMFused']:
            mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
            wrap.disable_casts(mod, 'backward', handle)
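
The shim in the new rnn_compat module above exists because the functions on `torch._C._VariableFunctions` live on a C-extension type whose attributes cannot be reassigned from Python. As the amp.py hunk's comment says, `init()` therefore swaps the whole `torch.nn.modules.rnn._VF` reference for this mutable Python object, and `wrap.cached_cast` can then freely replace its per-cell attributes. A standalone sketch of the same indirection idea (not apex code):

```python
import torch

class TinyShim(object):
    """Forwards calls by name to torch._C._VariableFunctions, but is mutable."""
    def __init__(self, names):
        for name in names:
            setattr(self, name,
                    lambda *args, _n=name, **kwargs:
                        getattr(torch._C._VariableFunctions, _n)(*args, **kwargs))

shim = TinyShim(['lstm_cell', 'gru_cell'])

# Because the shim's attributes are ordinary Python attributes, a casting
# wrapper can be swapped in later; that is exactly what cached_cast needs.
orig_lstm_cell = shim.lstm_cell
shim.lstm_cell = lambda *args, **kwargs: orig_lstm_cell(*args, **kwargs)
```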
 import unittest
 from apex import amp
+import random
 import torch
 from torch import nn
...
@@ -93,5 +94,23 @@ class TestRnns(unittest.TestCase):
                                bidirectional=bidir)
                 self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

+    def test_rnn_packed_sequence(self):
+        num_layers = 2
+        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
+        for typ in [torch.float, torch.half]:
+            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
+            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
+                          reverse=True)
+            # `pack_padded_sequence` breaks if default tensor type is non-CPU
+            torch.set_default_tensor_type(torch.FloatTensor)
+            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
+            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
+            output, _ = rnn(packed_seq, hidden)
+            self.assertEqual(output.data.type(), HALF)
+            output.data.float().sum().backward()
+            self.assertEqual(x.grad.dtype, x.dtype)
+
 if __name__ == '__main__':
     unittest.main()
...
@@ -111,18 +111,24 @@ def as_inplace(fns):

 def has_func(mod, fn):
     if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
         return fn in mod.function_classes
+    elif isinstance(mod, dict):
+        return fn in mod
     else:
         return hasattr(mod, fn)

 def get_func(mod, fn):
     if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
         return mod.function_classes[fn]
+    elif isinstance(mod, dict):
+        return mod[fn]
     else:
         return getattr(mod, fn)

 def set_func(mod, fn, new_fn):
     if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
         mod.function_classes[fn] = new_fn
+    elif isinstance(mod, dict):
+        mod[fn] = new_fn
     else:
         setattr(mod, fn, new_fn)
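
The new `elif isinstance(mod, dict)` branches are what let the rest of the patch treat `torch.nn.modules.rnn._rnn_impls`, a plain name-to-function dict on the PyTorch versions this targets, the same way as a module or a THNN function backend. A toy demonstration of the three helpers against a dict (the import path is an assumption, and the dict is a stand-in, not the real `_rnn_impls`):

```python
from apex.amp.utils import has_func, get_func, set_func  # assumed import path

# Stand-in for torch.nn.modules.rnn._rnn_impls: a dict of callables.
impls = {'LSTM': lambda *args: 'real lstm', 'GRU': lambda *args: 'real gru'}

assert has_func(impls, 'LSTM')
orig = get_func(impls, 'LSTM')

def wrapped(*args):
    # ... a casting wrapper would convert args to fp16 here ...
    return orig(*args)

set_func(impls, 'LSTM', wrapped)
assert get_func(impls, 'LSTM') is wrapped
```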
...
@@ -165,3 +171,23 @@ def synthesize_flattened_rnn_weights(fp32_weights,
             fp16_layer_weights.append(w_fp16)
         fp16_weights.append(fp16_layer_weights)
     return fp16_weights
+
+# Roughly same as above, just the `fp32_weights` aren't nested.
+# Code kept separate for readability.
+def new_synthesize_flattened_rnn_weights(fp32_weights,
+                                         fp16_flat_tensor,
+                                         rnn_fn='',
+                                         verbose=False):
+    fp16_weights = []
+    fp32_base_ptr = fp32_weights[0].data_ptr()
+    for w_fp32 in fp32_weights:
+        w_fp16 = w_fp32.new().half()
+        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
+        w_fp16.set_(fp16_flat_tensor.storage(),
+                    offset,
+                    w_fp32.shape)
+        w_fp16.copy_(w_fp32)
+        if verbose:
+            print('Float->Half ({})'.format(rnn_fn))
+        fp16_weights.append(w_fp16)
+    return fp16_weights
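
The helper added above computes offsets from raw data pointers, so it assumes the fp32 weights are laid out back-to-back in one contiguous allocation; each weight's element offset from the base pointer then says where its fp16 alias should sit inside `fp16_flat_tensor`. A small worked example, assuming the function is importable from `apex.amp.utils`:

```python
import torch
from apex.amp.utils import new_synthesize_flattened_rnn_weights  # assumed path

# One flat fp32 buffer with two "weights" viewing into it, mimicking a
# flattened RNN weight layout.
flat_fp32 = torch.arange(12, dtype=torch.float32)
w0 = flat_fp32[:8].view(2, 4)   # element offset 0
w1 = flat_fp32[8:]              # element offset 8

flat_fp16 = torch.zeros(12, dtype=torch.float16)
fp16_views = new_synthesize_flattened_rnn_weights([w0, w1], flat_fp16)

# Each returned tensor aliases flat_fp16 at the same offset/shape as its fp32
# counterpart and holds a half-precision copy of the values.
assert fp16_views[0].shape == w0.shape
assert torch.equal(flat_fp16[8:], w1.half())
```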
...
@@ -205,6 +205,38 @@ def rnn_cast(backend, fn, handle, verbose=False):
         return fwd_wrapper
     utils.set_func_save(handle, backend, fn, rnn_wrapper)

+def new_rnn_cast(fn, handle, verbose=False):
+    mod = torch.nn.modules.rnn._rnn_impls
+    orig_fn = utils.get_func(mod, fn)
+    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
+
+    @functools.wraps(orig_fn)
+    def wrapper(*args, **kwargs):
+        # Exact call signature from modules/rnn.py
+        assert len(args) == 9
+        assert len(kwargs) == 0
+        if isinstance(args[6], bool):
+            params_idx = 2  # Not PackedSequence case
+        else:
+            params_idx = 3  # PackedSequence case
+        new_args = []
+        for i, arg in enumerate(args):
+            if i == params_idx:
+                num_params = sum([x.numel() for x in arg])
+                fp16_weight_buf = args[0].new_empty((num_params,),
+                                                    dtype=torch.half)
+                casted_weights = utils.new_synthesize_flattened_rnn_weights(
+                    arg, fp16_weight_buf, fn, verbose)
+                new_args.append(casted_weights)
+            elif utils.is_fp_tensor(arg):
+                new_args.append(cast_fn(arg))
+            else:
+                new_args.append(arg)
+        return orig_fn(*new_args)
+
+    utils.set_func_save(handle, mod, fn, wrapper)
+
 def disable_casts(mod, fn, handle):
     if not utils.has_func(mod, fn):
         return
...
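
One detail worth spelling out from `new_rnn_cast`: the `isinstance(args[6], bool)` test distinguishes the padded and PackedSequence call shapes coming out of `torch/nn/modules/rnn.py`. The argument layouts below are reconstructed from that era of PyTorch rather than stated in this diff, so treat them as an assumption:

```python
# Assumed backend call shapes (argument names are reconstructions, not from the diff):
PADDED_ARGS = ('input', 'hx', 'params', 'has_biases', 'num_layers',
               'dropout', 'training', 'bidirectional', 'batch_first')
PACKED_ARGS = ('data', 'batch_sizes', 'hx', 'params', 'has_biases',
               'num_layers', 'dropout', 'training', 'bidirectional')

def params_index(args):
    # args[6] is the boolean `training` flag for padded input, but the float
    # `dropout` value once batch_sizes is prepended for packed input.
    return 2 if isinstance(args[6], bool) else 3

assert PADDED_ARGS[params_index((None,) * 6 + (True, None, None))] == 'params'
assert PACKED_ARGS[params_index((None,) * 6 + (0.5, None, None))] == 'params'
```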