Commit 2445031d authored by Michael Carilli's avatar Michael Carilli
Browse files

Cherry picking RNN fix

parent 612d4193
......@@ -11,10 +11,15 @@ def _gen_VF_wrapper(name):
# Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version.
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
# imported at module scope within torch.nn.modules.rnn). This should
# not affect third-party importers of _VF.py.
class VariableFunctionsShim(object):
def __init__(self):
for name in RNN_NAMES:
setattr(self, name + '_cell', _gen_VF_wrapper(name + '_cell'))
for suffix in ['', '_cell']:
fn_name = name + suffix
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
def has_old_rnns():
try:
......
from . import compat
from . import utils
from ._amp_state import _amp_state
from . import rnn_compat
import functools
......@@ -219,7 +220,17 @@ def rnn_cast(backend, fn, handle, verbose=False):
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
mod = torch.nn.modules.rnn._rnn_impls
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
# _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
# which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
mod = torch.nn.modules.rnn._rnn_impls
else:
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
@functools.wraps(orig_fn)
......
......@@ -6,7 +6,8 @@ print_banner() {
print_banner "Distributed status: $1"
DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"
# DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"
DATADIR="/opt/home/apex/examples/imagenet/"
if [ "$1" == "single_gpu" ]
then
......
#!/bin/bash
cp ../common/* .
bash run_test.sh distributed
bash run_test.sh single_gpu
#!/bin/bash
cp ../common/* .
bash run_test.sh single_gpu
bash run_test.sh distributed
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment