Commit de3f3fea authored by rohithkrn's avatar rohithkrn
Browse files

add bfloat16 register functions, enable rnn functions, enable promote functions

parent 6e14df49
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
register_half_function, register_bfloat16_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
......@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
for model in models:
# Patch the forward method to cast incoming data to the correct type, and
# outgoing data to float32, so "the user never needs to call .half()."
# outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
......
......@@ -213,8 +213,8 @@ def lazy_init_no_master_weights(self):
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
......
......@@ -30,6 +30,9 @@ def half_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def bfloat16_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)
def float_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
......@@ -48,6 +51,11 @@ def register_half_function(module, name):
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
def register_bfloat16_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))
def register_float_function(module, name):
if not hasattr(module, name):
......@@ -116,11 +124,11 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
# 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
# methods on FloatTensor, since they're distinct types.
if compat.tensor_is_float_tensor():
for fn in getattr(tensor_overrides, low_prec_funcs):
wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_low_prec,
for fn in getattr(tensor_overrides, 'FP16_FUNCS'):
wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
handle, try_caching=True, verbose=verbose)
for fn in tensor_overrides.FP32_FUNCS:
wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
handle, try_caching=False, verbose=verbose)
# 2) Enable type-promotion on multi-arg functions and methods.
......@@ -136,17 +144,17 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
# 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
if compat.tensor_is_float_tensor():
for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
low_prec_tensor],
torch.cuda.HalfTensor],
promote_table):
for fn in getattr(tensor_overrides, list_name):
promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16.
# 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
# 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
......@@ -158,7 +166,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
tensor_overrides.CASTS)):
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
# 5) RNNs + RNN cells are whitelisted specially
......@@ -169,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
# Wrap all the rnns
for x in rnn_compat.RNN_NAMES:
wrap.new_rnn_cast(x.upper(), handle, verbose)
wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)
# Wrap all the RNN cells
rnn_compat.whitelist_rnn_cells(handle, verbose)
rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)
# 6) Place error+print message on banned functions.
# Or, if allow_banned, then cast to FP32.
......
......@@ -16,6 +16,9 @@ class Properties(object):
"opt_level" : None,
"cast_model_type" : None,
"patch_torch_functions" : False,
# TODO: patch_torch_functions_type could probably be unified with
# patch_torch_functions. Currently introducing a new attribute
# to be on the safer side and not break stuff.
"patch_torch_functions_type" : None,
"keep_batchnorm_fp32" : None,
"master_weights" : None,
......@@ -390,7 +393,7 @@ def initialize(
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
maybe_print("Defaults for this optimization level are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
_amp_state.min_loss_scale = min_loss_scale
_amp_state.max_loss_scale = max_loss_scale
......@@ -417,7 +420,7 @@ def initialize(
maybe_print("After processing overrides, optimization options are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
......
......@@ -28,7 +28,7 @@ def has_old_rnns():
except:
return False
def whitelist_rnn_cells(handle, verbose):
def whitelist_rnn_cells(cast_fn, handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
......@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
wrap.cached_cast(mod, fn, cast_fn, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
......
......@@ -200,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
fp16_weights.append(fp16_layer_weights)
return fp16_weights
def _str_from_dtype(dtype=torch.float16):
type_to_str = {torch.float16 : 'Half',
torch.bfloat16 : 'BFloat16'}
return type_to_str[dtype]
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
dtype=torch.float16,
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().half()
w_fp16 = w_fp32.new().to(dtype=dtype)
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
......@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor'])
or types == set(['BFloat16Tensor', 'FloatTensor'])):
new_args = utils.casted_args(cast_fn,
args,
kwargs)
......@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
elif types == set(['HalfTensor', 'FloatTensor']):
elif (types == set(['HalfTensor', 'FloatTensor']) or
types == set(['BFloat16Tensor', 'FloatTensor'])):
cast_seq = utils.casted_args(maybe_float,
seq, {})
return orig_fn(cast_seq, *args, **kwargs)
......@@ -121,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs)
if 'HalfTensor' in types:
if 'HalfTensor' in types or 'BFloat16Tensor' in types:
if custom_err_msg:
raise NotImplementedError(custom_err_msg)
else:
raise NotImplementedError('Cannot call in-place function ' +
'{} with fp16 arguments.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
......@@ -139,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if utils.type_string(arg0) == 'HalfTensor':
if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
raise NotImplementedError('Cannot call in-place method ' +
'{} on fp16 Tensors.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
......@@ -221,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
return fwd_wrapper
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
def new_rnn_cast(fn, cast_fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
......@@ -234,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
......@@ -249,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
else:
params_idx = 3 # PackedSequence case
if cast_fn == utils.maybe_half:
dtype = torch.half
elif cast_fn == utils.maybe_bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=torch.half)
dtype=dtype)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, verbose)
arg, fp16_weight_buf, fn, dtype, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment