Commit de3f3fea authored by rohithkrn

add bfloat16 register functions, enable rnn functions, enable promote functions

parent 6e14df49
-from .amp import init, half_function, float_function, promote_function,\
-    register_half_function, register_float_function, register_promote_function
+from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
+    register_half_function, register_bfloat16_function, register_float_function, register_promote_function
 from .handle import scale_loss, disable_casts
 from .frontend import initialize, state_dict, load_state_dict
 from ._amp_state import master_params, _amp_state
@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     for model in models:
         # Patch the forward method to cast incoming data to the correct type, and
-        # outgoing data to float32, so "the user never needs to call .half()."
+        # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
         # I like writing things explicitly more than decorators.
         def patch_forward(old_fwd):
             def new_fwd(*args, **kwargs):
...
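
For orientation, here is a minimal sketch of what such a patched forward does, assuming the goal is "cast floating-point inputs down, cast floating-point outputs back to float32". This is an illustration only, not apex's actual wrapper; `patch_forward_sketch` and `input_dtype` are hypothetical names.

```python
import functools
import torch

def patch_forward_sketch(module, input_dtype=torch.bfloat16):
    """Sketch: cast floating-point inputs to input_dtype, outputs back to float32."""
    old_fwd = module.forward

    @functools.wraps(old_fwd)
    def new_fwd(*args, **kwargs):
        cast = lambda t: t.to(input_dtype) if torch.is_tensor(t) and t.is_floating_point() else t
        args = tuple(cast(a) for a in args)
        kwargs = {k: cast(v) for k, v in kwargs.items()}
        out = old_fwd(*args, **kwargs)
        return out.float() if torch.is_tensor(out) and out.is_floating_point() else out

    module.forward = new_fwd
    return module
```
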
@@ -213,8 +213,8 @@ def lazy_init_no_master_weights(self):
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)
             else:
-                raise TypeError("Optimizer's parameters must be either "
-                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                raise TypeError("Optimizer's parameters must be one of "
+                                "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
                                 "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
...
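
For reference, the strings compared against in this branch come from `Tensor.type()`; a quick check of what those look like for the dtypes involved (the CUDA line assumes a GPU is available):

```python
import torch

print(torch.zeros(1).type())                        # torch.FloatTensor
print(torch.zeros(1, dtype=torch.bfloat16).type())  # torch.BFloat16Tensor
if torch.cuda.is_available():
    print(torch.zeros(1, device='cuda').half().type())  # torch.cuda.HalfTensor
```
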
@@ -30,6 +30,9 @@ def half_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
     return _decorator_helper(fn, utils.maybe_half, wrap_fn)

+def bfloat16_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)

 def float_function(fn):
     wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
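
The new decorator mirrors `half_function`, so user code would presumably apply it the same way; a hedged usage sketch (the decorated function below is made up for illustration):

```python
from apex import amp

@amp.bfloat16_function
def attention_scores(q, k):
    # amp should cast floating-point inputs to bfloat16 before this body runs.
    return q.matmul(k.transpose(-1, -2))
```
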
@@ -48,6 +51,11 @@ def register_half_function(module, name):
                              name, module))
     _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

+def register_bfloat16_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))

 def register_float_function(module, name):
     if not hasattr(module, name):
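
And the registration-style equivalent, analogous to `register_half_function`: presumably called before amp is initialized so the cast gets installed when functions are patched. The module/function chosen below are illustrative only, and passing `torch.bfloat16` as `patch_type` is an assumption based on the `init()` signature shown in the hunks below.

```python
import torch
from apex import amp

# Illustrative: request that torch.bmm run its inputs in bfloat16.
amp.register_bfloat16_function(torch, 'bmm')
amp.init(patch_type=torch.bfloat16)
```
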
@@ -116,11 +124,11 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     # methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in getattr(tensor_overrides, low_prec_funcs):
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_low_prec,
+        for fn in getattr(tensor_overrides, 'FP16_FUNCS'):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
+            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)

     # 2) Enable type-promotion on multi-arg functions and methods.
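
Conceptually, `wrap.cached_cast` swaps a function for a wrapper that casts floating-point tensor arguments with the given cast function before calling through. A much-simplified sketch under that assumption (the real version also caches parameter casts via the handle):

```python
import torch

def cached_cast_sketch(mod, fn_name, cast_fn):
    """Sketch: wrap mod.fn_name so floating-point tensor args are cast first."""
    orig_fn = getattr(mod, fn_name)

    def wrapper(*args, **kwargs):
        cast = lambda t: cast_fn(t) if torch.is_tensor(t) and t.is_floating_point() else t
        return orig_fn(*(cast(a) for a in args),
                       **{k: cast(v) for k, v in kwargs.items()})

    setattr(mod, fn_name, wrapper)
```
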
@@ -136,17 +144,17 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                                low_prec_tensor],
+                                                                torch.cuda.HalfTensor],
                                                                promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)

-    # 3) For any in-place version of a blacklist function, error if any input is fp16.
+    # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
     # NB: this is overly conservative.
     for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
         wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

-    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
+    # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
     for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
         wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
     if compat.tensor_is_float_tensor():
@@ -158,7 +166,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
                                 tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

     # 5) RNNs + RNN cells are whitelisted specially
@@ -169,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
         torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
     # Wrap all the rnns
     for x in rnn_compat.RNN_NAMES:
-        wrap.new_rnn_cast(x.upper(), handle, verbose)
+        wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)

     # Wrap all the RNN cells
-    rnn_compat.whitelist_rnn_cells(handle, verbose)
+    rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)

     # 6) Place error+print message on banned functions.
     # Or, if allow_banned, then cast to FP32.
...
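
`maybe_low_prec` is bound elsewhere in `init()` (presumably to `utils.maybe_half` or `utils.maybe_bfloat16` depending on `patch_type`); a sketch of the kind of cast helper being threaded through here, under that assumption:

```python
import torch

def make_maybe_low_prec(dtype=torch.bfloat16):
    """Sketch: return a cast function that lowers float32 tensors to dtype
    and leaves everything else untouched."""
    def maybe_low_prec(x):
        if torch.is_tensor(x) and x.dtype == torch.float32:
            return x.to(dtype)
        return x
    return maybe_low_prec
```
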
@@ -16,6 +16,9 @@ class Properties(object):
             "opt_level" : None,
             "cast_model_type" : None,
             "patch_torch_functions" : False,
+            # TODO: patch_torch_functions_type could probably be unified with
+            # patch_torch_functions. Currently introducing a new attribute
+            # to be on the safer side and not break stuff.
             "patch_torch_functions_type" : None,
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,
@@ -390,7 +393,7 @@ def initialize(
         maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
         maybe_print("Defaults for this optimization level are:", True)
         for k, v in _amp_state.opt_properties.options.items():
-            maybe_print("{:22} : {}".format(k, v), True)
+            maybe_print("{:26} : {}".format(k, v), True)

     _amp_state.min_loss_scale = min_loss_scale
     _amp_state.max_loss_scale = max_loss_scale
@@ -417,7 +420,7 @@ def initialize(
     maybe_print("After processing overrides, optimization options are:", True)
     for k, v in _amp_state.opt_properties.options.items():
-        maybe_print("{:22} : {}".format(k, v), True)
+        maybe_print("{:26} : {}".format(k, v), True)

     return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
...
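
The width bump from `{:22}` to `{:26}` is just column alignment: the new `patch_torch_functions_type` key is 26 characters long, so the old width let it collide with the separator.

```python
key = "patch_torch_functions_type"      # 26 characters
print("{:22} : {}".format(key, None))   # key overflows the 22-char column
print("{:26} : {}".format(key, None))   # all keys stay aligned at 26
```
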
@@ -28,7 +28,7 @@ def has_old_rnns():
     except:
         return False

-def whitelist_rnn_cells(handle, verbose):
+def whitelist_rnn_cells(cast_fn, handle, verbose):
     # Different module + function names in old/new RNN cases
     if has_old_rnns():
         fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
     # Insert casts on cell functions
     for fn in fn_names:
-        wrap.cached_cast(mod, fn, utils.maybe_half, handle,
+        wrap.cached_cast(mod, fn, cast_fn, handle,
                          try_caching=True, verbose=verbose)

     if has_old_rnns():
...
@@ -200,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
         fp16_weights.append(fp16_layer_weights)
     return fp16_weights

+def _str_from_dtype(dtype=torch.float16):
+    type_to_str = {torch.float16 : 'Half',
+                   torch.bfloat16 : 'BFloat16'}
+    return type_to_str[dtype]
+
 # Roughly same as above, just the `fp32_weights` aren't nested.
 # Code kept separate for readability.
 def new_synthesize_flattened_rnn_weights(fp32_weights,
                                          fp16_flat_tensor,
                                          rnn_fn='',
+                                         dtype=torch.float16,
                                          verbose=False):
     fp16_weights = []
     fp32_base_ptr = fp32_weights[0].data_ptr()
     for w_fp32 in fp32_weights:
-        w_fp16 = w_fp32.new().half()
+        w_fp16 = w_fp32.new().to(dtype=dtype)
         offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
         w_fp16.set_(fp16_flat_tensor.storage(),
                     offset,
                     w_fp32.shape)
         w_fp16.copy_(w_fp32)
         if verbose:
-            print('Float->Half ({})'.format(rnn_fn))
+            print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
         fp16_weights.append(w_fp16)
     return fp16_weights
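
A standalone sketch of the viewing trick used above: each low-precision weight is created as a view into one flat buffer at the same element offset its fp32 counterpart occupies in the original allocation. This assumes the fp32 weights already share one contiguous storage, as flattened RNN weights do; the tensors below are toy stand-ins.

```python
import torch

fp32_flat = torch.randn(6)
fp32_weights = [fp32_flat[:4].view(2, 2), fp32_flat[4:]]     # views into one allocation

bf16_flat = torch.empty(6, dtype=torch.bfloat16)             # destination flat buffer
base_ptr = fp32_weights[0].data_ptr()

bf16_weights = []
for w in fp32_weights:
    offset = (w.data_ptr() - base_ptr) // w.element_size()   # offset in elements
    w_bf16 = torch.empty(0, dtype=torch.bfloat16)
    w_bf16.set_(bf16_flat.storage(), offset, w.shape)        # view into bf16_flat
    w_bf16.copy_(w)                                          # downcasting copy
    bf16_weights.append(w_bf16)
```
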
@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
         if len(types) <= 1:
             return orig_fn(*args, **kwargs)
-        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
+        elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor'])
+                                  or types == set(['BFloat16Tensor', 'FloatTensor'])):
             new_args = utils.casted_args(cast_fn,
                                          args,
                                          kwargs)
@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
         types = set([utils.type_string(x) for x in seq])
         if len(types) <= 1:
             return orig_fn(seq, *args, **kwargs)
-        elif types == set(['HalfTensor', 'FloatTensor']):
+        elif (types == set(['HalfTensor', 'FloatTensor']) or
+              types == set(['BFloat16Tensor', 'FloatTensor'])):
             cast_seq = utils.casted_args(maybe_float,
                                          seq, {})
             return orig_fn(cast_seq, *args, **kwargs)
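
Both promotion hunks implement the same rule: a mix of float32 with exactly one lower-precision type (fp16 or bfloat16) is promoted to float32; anything else falls through untouched. Condensed into a small helper for illustration:

```python
def should_promote_to_float(type_strings):
    """type_strings: set of Tensor.type() suffixes, e.g. {'HalfTensor', 'FloatTensor'}."""
    if len(type_strings) <= 1:
        return False   # single dtype: nothing to promote
    return (type_strings == {'HalfTensor', 'FloatTensor'} or
            type_strings == {'BFloat16Tensor', 'FloatTensor'})
```
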
@@ -121,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         types = utils.collect_fp_tensor_types(args, kwargs)
-        if 'HalfTensor' in types:
+        if 'HalfTensor' in types or 'BFloat16Tensor' in types:
             if custom_err_msg:
                 raise NotImplementedError(custom_err_msg)
             else:
                 raise NotImplementedError('Cannot call in-place function ' +
-                                          '{} with fp16 arguments.'.format(fn))
+                                          '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             return orig_fn(*args, **kwargs)
     utils.set_func_save(handle, mod, fn, wrapper)
@@ -139,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
     @functools.wraps(orig_fn)
     def wrapper(arg0, *args, **kwargs):
         assert compat.is_tensor_like(arg0)
-        if utils.type_string(arg0) == 'HalfTensor':
+        if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
             raise NotImplementedError('Cannot call in-place method ' +
-                                      '{} on fp16 Tensors.'.format(fn))
+                                      '{} with fp16 or bfloat16 args.'.format(fn))
         else:
             cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
             new_args = utils.casted_args(cast_fn, args, kwargs)
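
Both guards follow the same shape: refuse in-place blacklist ops when a low-precision tensor is involved. A minimal standalone sketch of that pattern (no handle bookkeeping, and a dtype check instead of apex's type-string check):

```python
import functools
import torch

LOW_PREC = (torch.float16, torch.bfloat16)

def guard_inplace(mod, fn_name):
    """Sketch: refuse to run mod.fn_name when any tensor arg is fp16/bfloat16."""
    orig_fn = getattr(mod, fn_name)

    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if any(torch.is_tensor(a) and a.dtype in LOW_PREC for a in args):
            raise NotImplementedError(
                'Cannot call in-place function {} with fp16 or bfloat16 args.'.format(fn_name))
        return orig_fn(*args, **kwargs)

    setattr(mod, fn_name, wrapper)
```
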
@@ -221,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
         return fwd_wrapper
     utils.set_func_save(handle, backend, fn, rnn_wrapper)

-def new_rnn_cast(fn, handle, verbose=False):
+def new_rnn_cast(fn, cast_fn, handle, verbose=False):
     # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
     # For rnn backend calls that route through _rnn_impls, we must patch the ref
     # that _rnn_impls stashed. For rnn backend calls that directly invoke
@@ -234,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
         assert isinstance(mod, rnn_compat.VariableFunctionsShim)
     fn = fn.lower()
     orig_fn = utils.get_func(mod, fn)
-    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
+    cast_fn = utils.verbosify(cast_fn, fn, verbose)
     @functools.wraps(orig_fn)
     def wrapper(*args, **kwargs):
         # Exact call signature from modules/rnn.py
@@ -249,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
         else:
             params_idx = 3 # PackedSequence case

+        if cast_fn == utils.maybe_half:
+            dtype = torch.half
+        elif cast_fn == utils.maybe_bfloat16:
+            dtype = torch.bfloat16
+        else:
+            raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
         new_args = []
         for i, arg in enumerate(args):
             if i == params_idx:
                 num_params = sum([x.numel() for x in arg])
                 fp16_weight_buf = args[0].new_empty((num_params,),
-                                                    dtype=torch.half)
+                                                    dtype=dtype)
                 casted_weights = utils.new_synthesize_flattened_rnn_weights(
-                    arg, fp16_weight_buf, fn, verbose)
+                    arg, fp16_weight_buf, fn, dtype, verbose)
                 new_args.append(casted_weights)
             elif utils.is_fp_tensor(arg):
                 new_args.append(cast_fn(arg))
...
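
The dtype selection added above keys on the identity of the cast function that was passed in; an equivalent standalone sketch, with hypothetical stand-ins for `utils.maybe_half` and `utils.maybe_bfloat16` (names and bodies are illustrations, not the apex implementations):

```python
import torch

def maybe_half(x):      # stand-in for utils.maybe_half
    return x.half() if torch.is_tensor(x) and x.is_floating_point() else x

def maybe_bfloat16(x):  # stand-in for utils.maybe_bfloat16
    return x.bfloat16() if torch.is_tensor(x) and x.is_floating_point() else x

def dtype_for_cast_fn(cast_fn):
    """Map a cast function to the buffer dtype used for the flattened RNN weights."""
    if cast_fn is maybe_half:
        return torch.half
    if cast_fn is maybe_bfloat16:
        return torch.bfloat16
    raise RuntimeError("Unsupported cast_fn; supports only maybe_half and maybe_bfloat16")
```
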