Commit 613997ea authored by Michael Carilli's avatar Michael Carilli
Browse files

No need for casts during optimizer step

parent ed8236fa
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .handle import scale_loss
from .handle import scale_loss, disable_casts
from .frontend import initialize
......@@ -2,6 +2,7 @@ import torch
from torch._six import container_abcs, string_classes
import functools
from ._amp_state import _amp_state
from .handle import disable_casts
from .scaler import LossScaler
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
......@@ -111,8 +112,8 @@ def _initialize(models, optimizers, properties):
check_optimizers(optimizers)
# Stash master weights before casting the model.
# if properties.master_weights:
# In the future, when FP16_Optimizer can be deprecated and master weights can
# become an attribute, remember to stash master weights before casting the model.
if properties.cast_model_type:
if properties.keep_batchnorm_fp32:
......@@ -125,6 +126,7 @@ def _initialize(models, optimizers, properties):
caster = functools.partial(to_type, properties.cast_model_type)
# Patch the forward method to cast incoming data to the correct type.
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
return old_fwd(*applier(args, caster),
......@@ -142,10 +144,10 @@ def _initialize(models, optimizers, properties):
if isinstance(optimizer, FusedAdam):
optimizers[i] = wrap_fused_adam(optimizer, properties)
if properties.loss_scale == "dynamic":
optimizers[i] = FP16_Optimizer_general(optimizers[i],
optimizers[i] = FP16_Optimizer_general(optimizer,
dynamic_loss_scale=True)
else:
optimizers[i] = FP16_Optimizer_general(optimizers[i],
optimizers[i] = FP16_Optimizer_general(optimizer,
static_loss_scale=properties.loss_scale)
else:
for optimizer in optimizers:
......@@ -154,6 +156,17 @@ def _initialize(models, optimizers, properties):
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale)
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
def patch_step(old_step):
def new_step(*args, **kwargs):
with disable_casts():
output = old_step(*args, **kwargs)
return output
return new_step
optimizer.step = patch_step(optimizer.step)
if optimizers_was_list:
if models_was_list:
......
......@@ -50,7 +50,7 @@ def scale_loss(loss,
iter_params(optimizer.param_groups),
iter_params(optimizer.param_groups),
loss_scale)
# In the future, once I have fused optimizers that enable sync-free dynamic loss scaling,
# For future fused optimizers that enable sync-free dynamic loss scaling,
# should_skip will always be False.
should_skip = optimizer.loss_scaler.update_scale()
if should_skip:
......@@ -66,6 +66,15 @@ def scale_loss(loss,
_amp_state.handle._clear_cache()
# Free function version of AmpHandle.disable_casts, another step on the
# path to removing the concept of "AmpHandle"
@contextlib.contextmanager
def disable_casts():
_amp_state.handle._is_active = False
yield
_amp_state.handle._is_active = True
class AmpHandle(object):
def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
self._enable_caching = enable_caching
......
from . import compat
from . import utils
from ._amp_state import _amp_state
import functools
......@@ -37,10 +38,16 @@ def cached_cast(mod, fn, cast_fn, handle,
utils.set_func_save(handle, mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
# is on the new API and I am free to get rid of handle, I can clean this up.
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
types = utils.collect_fp_tensor_types(args, kwargs)
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
......@@ -65,6 +72,9 @@ def sequence_promote(mod, fn, handle, verbose=False):
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(seq, *args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(seq, *args, **kwargs)
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
......@@ -86,6 +96,9 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if not _amp_state.handle.is_active():
return orig_fn(arg0, *args, **kwargs)
if utils.type_string(arg0) == 'HalfTensor':
cast_fn = utils.maybe_half
elif utils.type_string(arg0) == 'FloatTensor':
......@@ -215,6 +228,9 @@ def new_rnn_cast(fn, handle, verbose=False):
assert len(args) == 9
assert len(kwargs) == 0
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
if isinstance(args[6], bool):
params_idx = 2 # Not PackedSequence case
else:
......
import torch
import torch.nn as nn
from torch.autograd import Variable
import apex
from apex.RNN.models import bidirectionalRNN, stackedRNN, RNNCell
from torch.nn._functions.rnn import LSTMCell
import itertools
torch.backends.cudnn.enabled=False
batch_first = False #not implemented yet
dropout = 0.0 #How to validate?
bidirectional = False #True works, but differs in definition to PyTorch
rnn_types = ['LSTM', 'GRU', 'ReLU', 'Tanh']
sizes = [8,4,2]
seq_sizes = sizes
hidden_sizes = sizes
inp_sizes = sizes
batch_sizes = sizes
num_layerss = sizes
biases = [True]
def copy_param_set(pyt_rnn, my_rnn, layer=0, reverse=False):
my_params = None
rnn = None
if isinstance(my_rnn, bidirectionalRNN):
rnn = my_rnn.fwd.rnns[layer] if not reverse else my_rnn.bckwrd.rnns[layer]
elif isinstance(my_rnn, stackedRNN):
rnn = my_rnn.rnns[layer]
else:
raise RuntimeError()
param_names = ['w_ih', 'w_hh', 'b_ih', 'b_hh']
if not hasattr(rnn, 'b_hh'):
param_names = param_names[:2]
my_params = [getattr(rnn, param_name) for param_name in param_names]
pyt_params = None
param_names = ['weight_ih_', 'weight_hh_', 'bias_ih_', 'bias_hh_']
reverse_str = '_reverse' if reverse else ''
if not hasattr(pyt_rnn, 'bias_hh_l0'):
param_names=param_names[:2]
pyt_params =[getattr(pyt_rnn, param_name + 'l' + str(layer) + reverse_str )
for param_name in param_names ]
for pyt_param, my_param in zip(pyt_params, my_params):
pyt_param.data.copy_(my_param.data)
def copy_all_params(pyt_rnn, my_rnn):
for layer in range(num_layers):
copy_param_set(pyt_rnn, my_rnn, layer)
if bidirectional:
copy_param_set(pyt_rnn, my_rnn, layer, bidirectional)
def compare_variables(v1, v2, msg, params):
diff = float((v1.data-v2.data).abs().max())
if diff > 1e-5:
print("Error of ", diff, " found for ", msg, " for case: ", str(params))
def compare_tuple_variables(t1, t2, msg, params):
for var1, var2 in zip(t1, t2):
compare_variables(var1, var2, msg, params)
def maybe_compare(v1, v2, msg, params):
if isinstance(v1, Variable) and isinstance(v2, Variable):
compare_variables(v1, v2, msg, params)
else:
compare_tuple_variables(v1, v2, msg, params)
product = list(itertools.product(rnn_types, seq_sizes, hidden_sizes, inp_sizes, batch_sizes, num_layerss, biases))
for test_case in product:
rnn_type, seq_size, hidden_size, inp_size, batch_size, num_layers, bias = test_case
inp = torch.cuda.FloatTensor(seq_size, batch_size, inp_size).uniform_()
if rnn_type == 'ReLU' or rnn_type == 'Tanh':
pytorch_rnn = nn.RNN(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, nonlinearity=rnn_type.lower()).cuda()
else:
pytorch_rnn = getattr(nn, rnn_type)(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional).cuda()
my_rnn = getattr(apex.RNN.models, rnn_type)(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional).cuda()
copy_all_params(pytorch_rnn, my_rnn)
pyt_inp = Variable(inp, requires_grad=True)
my_inp = Variable(inp, requires_grad=True)
my_out, my_hiddens = my_rnn(my_inp)
pyt_out, pyt_hiddens = pytorch_rnn(pyt_inp)
pyt_out.sum().backward()
my_out.sum().backward()
maybe_compare(pyt_out, my_out, "out", test_case)
#If there's only one hidden state PyTorch doesn't return it in a tuple,
#apex does, so we wrap PyTorch's returned hidden state in a tuple.
if not isinstance(pyt_hiddens, tuple):
pyt_hiddens = (pyt_hiddens,)
try:
for i, (pyt_hid, my_hid) in enumerate(zip(pyt_hiddens, my_hiddens)):
maybe_compare(pyt_hid, my_hid , "hx_"+str(i), test_case)
except ValueError:
maybe_compare(pyt_hiddens, my_hiddens , "hx_0", test_case)
maybe_compare(pyt_inp.grad, my_inp.grad, "inp.grad", test_case)
print("Test passed.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment