Commit 7c36c412 authored by Michael Carilli's avatar Michael Carilli
Browse files

Forward+backward compatibility fix around https://github.com/pytorch/pytorch/pull/15744

parent 1b903852
...@@ -11,10 +11,15 @@ def _gen_VF_wrapper(name): ...@@ -11,10 +11,15 @@ def _gen_VF_wrapper(name):
# Some python magic to generate an object that has the rnn cell functions # Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version. # defined on it, all of which call into corresponding _VF version.
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
# imported at module scope within torch.nn.modules.rnn). This should
# not affect third-party importers of _VF.py.
class VariableFunctionsShim(object): class VariableFunctionsShim(object):
def __init__(self): def __init__(self):
for name in RNN_NAMES: for name in RNN_NAMES:
setattr(self, name + '_cell', _gen_VF_wrapper(name + '_cell')) for suffix in ['', '_cell']:
fn_name = name + suffix
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
def has_old_rnns(): def has_old_rnns():
try: try:
......
from . import compat from . import compat
from . import utils from . import utils
from . import rnn_compat
import functools import functools
...@@ -206,7 +207,17 @@ def rnn_cast(backend, fn, handle, verbose=False): ...@@ -206,7 +207,17 @@ def rnn_cast(backend, fn, handle, verbose=False):
utils.set_func_save(handle, backend, fn, rnn_wrapper) utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False): def new_rnn_cast(fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
# _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
# which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
mod = torch.nn.modules.rnn._rnn_impls mod = torch.nn.modules.rnn._rnn_impls
else:
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn) orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose) cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
@functools.wraps(orig_fn) @functools.wraps(orig_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment