Unverified commit b2da92fc authored by Peng, committed by GitHub

Merge pull request #5 from rohithkrn/apex_amp_bfp16

Introduce new optimization levels for BFloat16 training
parents 65490af6 e1267a9a
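For context, a minimal usage sketch of the new O4 level (not part of the diff below; the model, optimizer, and data are illustrative placeholders, and it assumes a CUDA device plus a PyTorch build with bfloat16 support and an apex build from this branch):

import torch
import torch.nn.functional as F
from apex import amp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# O4 patches Torch functions so Tensor Core-friendly ops run in bfloat16.
# The loss scale stays at 1 because bfloat16 shares fp32's dynamic range.
model, optimizer = amp.initialize(model, optimizer, opt_level="O4")

data = torch.randn(64, 1024, device="cuda")
target = torch.randn(64, 1024, device="cuda")
loss = F.mse_loss(model(data), target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()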
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
register_half_function, register_bfloat16_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
......@@ -80,10 +80,10 @@ def check_params_fp32(models):
for model in models:
for name, param in model.named_parameters():
if param.is_floating_point():
if 'Half' in param.type():
if 'Half' in param.type() or 'BFloat16' in param.type():
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
"When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
"on your model before passing it, no matter what optimization level you choose.".format(
name, param.type()))
elif not param.is_cuda:
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
......@@ -137,7 +137,7 @@ class O2StateDictHook(object):
def __call__(self, module, state_dict, prefix, local_metadata):
for key in state_dict:
param = state_dict[key]
if 'Half' in param.type():
if 'Half' in param.type() or 'BFloat16' in param.type():
param = param.to(torch.float32)
state_dict[key] = param
......@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
for model in models:
# Patch the forward method to cast incoming data to the correct type, and
# outgoing data to float32, so "the user never needs to call .half()."
# outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
......@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
handle = amp_init(loss_scale=properties.loss_scale,
patch_type=properties.patch_torch_functions_type,
verbose=(_amp_state.verbosity == 2))
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
......
import types
from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
from ._amp_state import maybe_print, _amp_state
import torch
from ..optimizers import FusedSGD
......@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
# maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
# .format(param.size()))
fp16_params_this_group.append(param)
......@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
......@@ -208,13 +208,13 @@ def lazy_init_no_master_weights(self):
stash.all_fp32_params = []
for i, param_group in enumerate(self.param_groups):
for i, param in enumerate(param_group['params']):
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
......@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
fp32_from_fp16_params_this_group = []
for i, param in enumerate(new_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
......@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
fp32_params_this_group.append(param)
new_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
......@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
# param.grad = None
else:
for param in new_group['params']:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
stash.all_fp16_grad_stash.append(None)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
stash.all_fp32_grad_stash.append(None)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
old_add_param_group(new_group)
......
......@@ -9,7 +9,6 @@ import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
......@@ -31,6 +30,9 @@ def half_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def bfloat16_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)
def float_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
......@@ -49,6 +51,11 @@ def register_half_function(module, name):
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
def register_bfloat16_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))
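A hedged sketch of the new user-facing hooks: `my_fused_gemm` and `mylib` are illustrative names, not part of this PR, and the casting only engages once amp is active (e.g. after amp.initialize(..., opt_level="O4") or amp.init(patch_type=torch.bfloat16)):

import torch
from apex import amp

@amp.bfloat16_function
def my_fused_gemm(a, b):
    # When amp is active, floating-point CUDA inputs are cast to torch.bfloat16
    # before this body runs.
    return a @ b

# Registration form for a function owned by another module; call it before
# amp.initialize so the registry is consumed during patching:
# amp.register_bfloat16_function(mylib, "fused_gemm")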
def register_float_function(module, name):
if not hasattr(module, name):
......@@ -65,7 +72,7 @@ def register_promote_function(module, name):
# Top-level function to insert _all_ the hooks.
def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
global _DECORATOR_HANDLE
if not enabled:
......@@ -87,16 +94,30 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
wrap.promote(mod, fn, handle, verbose)
_USER_PROMOTE_REGISTRY.clear()
# Conditionally choose between the fp16 and bfloat16 function lists to cast and cache
if patch_type == torch.float16:
low_prec_funcs = 'FP16_FUNCS'
maybe_low_prec = utils.maybe_half
low_prec_tensor = torch.cuda.HalfTensor
elif patch_type == torch.bfloat16:
low_prec_funcs = 'BFLOAT16_FUNCS'
maybe_low_prec = utils.maybe_bfloat16
low_prec_tensor = torch.cuda.BFloat16Tensor
else:
raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." +
"Supported types are: torch.float16 and torch.bfloat16.")
# 1) Force-{fp16, fp32} on white- / black-list functions
override_modules = [functional_overrides,
torch_overrides,
tensor_overrides]
cast_table = [('FP16_FUNCS', utils.maybe_half),
cast_table = [(low_prec_funcs, maybe_low_prec),
('FP32_FUNCS', utils.maybe_float)]
for module, (list_name, cast_fn) in itertools.product(override_modules,
cast_table):
for fn in getattr(module, list_name):
try_caching = (cast_fn == utils.maybe_half)
try_caching = (cast_fn == maybe_low_prec)
wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
try_caching, verbose)
......@@ -128,12 +149,12 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
for fn in getattr(tensor_overrides, list_name):
promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16.
# 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
# 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
......@@ -141,7 +162,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
# 4) For other in-place methods, match the type of self tensor
for fn in utils.as_inplace(itertools.chain(
tensor_overrides.FP16_FUNCS,
getattr(tensor_overrides, low_prec_funcs),
tensor_overrides.CASTS)):
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
......@@ -156,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
# Wrap all the rnns
for x in rnn_compat.RNN_NAMES:
wrap.new_rnn_cast(x.upper(), handle, verbose)
wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)
# Wrap all the RNN cells
rnn_compat.whitelist_rnn_cells(handle, verbose)
rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)
# 6) Place error+print message on banned functions.
# Or, if allow_banned, then cast to FP32.
......
......@@ -28,7 +28,8 @@ def is_floating_point(x):
torch_type = x.type()
return torch_type.endswith('FloatTensor') or \
torch_type.endswith('HalfTensor') or \
torch_type.endswith('DoubleTensor')
torch_type.endswith('DoubleTensor') or \
torch_type.endswith('BFloat16Tensor')
except AttributeError:
return False
......
......@@ -16,6 +16,10 @@ class Properties(object):
"opt_level" : None,
"cast_model_type" : None,
"patch_torch_functions" : False,
# TODO: patch_torch_functions_type could probably be unified with
# patch_torch_functions. Currently introducing a new attribute
# to be on the safer side and not break stuff.
"patch_torch_functions_type" : None,
"keep_batchnorm_fp32" : None,
"master_weights" : None,
"loss_scale" : 1.0,
......@@ -53,7 +57,7 @@ class Properties(object):
if name in self.options:
# print("setting {} {}".format(name, value))
if name == "cast_model_type":
if self.opt_level == "O1" and value is not None:
if self.opt_level in {"O1", "O4"} and value is not None:
if value is not False:
if value is not torch.float32:
warn_or_err("O1 inserts casts around Torch functions rather than "
......@@ -63,13 +67,25 @@ class Properties(object):
"cast_model_type was {}".format(value))
self.options[name] = value
elif name == "patch_torch_functions":
if self.opt_level != "O1" and value:
if self.opt_level not in {"O1", "O4"} and value:
warn_or_err("Currently, patch_torch_functions=True should only be set by "
"selecting opt_level='O1'.")
"selecting opt_level='O1' or 'O4'.")
self.options[name] = value
elif name == "patch_torch_functions_type":
if self.opt_level not in {"O1", "O4"} and value is not None:
warn_or_err("Currently, patch_torch_functions_type should only be set by "
"selecting opt_level='O1' or 'O4'.")
elif self.opt_level == "O1" and value != torch.float16:
warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
"for opt_level='O1.")
elif self.opt_level == "O4" and value != torch.bfloat16:
warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
"for opt_level='O4.")
else:
self.options[name] = value
elif name == "keep_batchnorm_fp32":
if self.opt_level == "O1" and value is not None:
warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
"to run in FP32, so keep_batchnorm_fp32 should be None." +
" keep_batchnorm_fp32 was {}".format(value))
if value == "False":
......@@ -82,9 +98,9 @@ class Properties(object):
"or None, found keep_batchnorm_fp32={}".format(value)
self.options[name] = value
elif name == "master_weights":
if self.opt_level == "O1" and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1. "
"With O1, your model weights themselves should be FP32.")
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1 and O4 . "
"With O1 and O4, your model weights themselves should be FP32.")
self.options[name] = value
elif name == "loss_scale":
if value == "dynamic":
......@@ -113,6 +129,7 @@ class O3:
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
......@@ -136,6 +153,7 @@ class O2:
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = "dynamic"
......@@ -158,6 +176,7 @@ class O1:
properties.opt_level = "O1"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.float16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = "dynamic"
......@@ -177,6 +196,7 @@ class O0:
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = None
properties.master_weights = False
properties.loss_scale = 1.0
......@@ -184,11 +204,54 @@ class O0:
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O4:
brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O4"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.bfloat16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = 1.0
return properties # modified in place so this isn't really necessary
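A hedged sketch of what the O4 preset means in practice (the printed dtypes are expectations given the defaults set above, and it assumes bfloat16-capable hardware):

import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model, opt = amp.initialize(model, opt, opt_level="O4")

# The weights themselves are untouched; only Torch functions are patched.
assert model.weight.dtype is torch.float32
out = model(torch.randn(4, 16, device="cuda"))  # F.linear is on the BFLOAT16 list
print(out.dtype)                                # expected: torch.bfloat16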
class O5:
brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
"to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
"your data pipeline.\n"\
"O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O5"
properties.cast_model_type = torch.bfloat16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = 1.0
return properties # modified in place so this isn't really necessary
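A hedged sketch of the O5 behavior described above (the module layout is illustrative and the printed dtypes are expectations):

import torch
from apex import amp

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model, opt = amp.initialize(model, opt, opt_level="O5")

print(model[0].weight.dtype)               # conv weights: expected torch.bfloat16
print(model[1].weight.dtype)               # batchnorm kept in torch.float32
print(next(amp.master_params(opt)).dtype)  # fp32 master copies held by the optimizer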
opt_levels = {"O3": O3(),
"O2": O2(),
"O1": O1(),
"O0": O0()}
"O0": O0(),
"O4": O4(),
"O5": O5()}
# allow user to directly pass Properties struct as well?
......@@ -199,6 +262,7 @@ def initialize(
opt_level="O1",
cast_model_type=None,
patch_torch_functions=None,
patch_torch_functions_type=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None,
......@@ -235,10 +299,11 @@ def initialize(
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
should run as if Amp were not present.
opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
"O0", "O1", "O2", and "O3", explained in detail above.
"O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above.
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
above.
patch_torch_functions (bool, optional, default=None): Optional property override.
patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override, see above.
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
passed as a string, must be the string "True" or "False".
master_weights (bool, optional, default=None): Optional property override.
......@@ -321,14 +386,14 @@ def initialize(
if opt_level not in opt_levels:
raise RuntimeError(
"Unexpected optimization level {}. ".format(opt_level) +
"Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
"Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
"not the number zero.")
else:
_amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
maybe_print("Defaults for this optimization level are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
_amp_state.min_loss_scale = min_loss_scale
_amp_state.max_loss_scale = max_loss_scale
......@@ -344,6 +409,8 @@ def initialize(
_amp_state.opt_properties.cast_model_type = cast_model_type
if patch_torch_functions is not None:
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
if patch_torch_functions_type is not None:
_amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
if keep_batchnorm_fp32 is not None:
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
if master_weights is not None:
......@@ -353,7 +420,7 @@ def initialize(
maybe_print("After processing overrides, optimization options are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
......
......@@ -26,6 +26,17 @@ FP16_FUNCS = [
'linear',
]
BFLOAT16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
FP32_FUNCS = [
# Interpolation/Upsampling TODO: Remove for 1.2
......
......@@ -15,6 +15,10 @@ FP16_FUNCS = [
'__matmul__',
]
BFLOAT16_FUNCS = [
'__matmul__',
]
FP32_FUNCS = [
'__ipow__',
'__pow__',
......@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
lst = getattr(_self_mod, attrname)
for fn in getattr(torch_overrides, attrname):
if hasattr(MODULE, fn):
......
......@@ -26,6 +26,27 @@ FP16_FUNCS = [
'mv',
]
BFLOAT16_FUNCS = [
# Low level functions wrapped by torch.nn layers.
# The wrapper layers contain the weights which are then passed in as a parameter
# to these functions.
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
FP32_FUNCS = [
# Pointwise
'acos',
......
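A hedged sketch of how the whitelist/blacklist split above behaves once the lower-level amp.init entry point is called with bfloat16 patching (assumes a PyTorch build with CUDA bfloat16 GEMM support; this mirrors the existing fp16 path):

import torch
from apex import amp

amp.init(loss_scale=1.0, patch_type=torch.bfloat16)

a = torch.randn(32, 32, device="cuda")
b = torch.randn(32, 32, device="cuda")
print(torch.mm(a, b).dtype)  # whitelisted BLAS op, inputs cast down: expected torch.bfloat16
print(torch.acos(a).dtype)   # blacklisted pointwise op, kept in torch.float32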
......@@ -28,7 +28,7 @@ def has_old_rnns():
except:
return False
def whitelist_rnn_cells(handle, verbose):
def whitelist_rnn_cells(cast_fn, handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
......@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
wrap.cached_cast(mod, fn, cast_fn, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
......
......@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
print('Float->Half ({})'.format(name))
return x.half()
def maybe_bfloat16(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_bfloat16(y) for y in x])
if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
return x
else:
if verbose:
print('Float->BFloat16 ({})'.format(name))
return x.bfloat16()
def maybe_float(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_float(y) for y in x])
......@@ -189,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
fp16_weights.append(fp16_layer_weights)
return fp16_weights
def _str_from_dtype(dtype=torch.float16):
type_to_str = {torch.float16 : 'Half',
torch.bfloat16 : 'BFloat16'}
return type_to_str[dtype]
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
dtype=torch.float16,
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().half()
w_fp16 = w_fp32.new().to(dtype=dtype)
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
......@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor'])
or types == set(['BFloat16Tensor', 'FloatTensor'])):
new_args = utils.casted_args(cast_fn,
args,
kwargs)
......@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
elif types == set(['HalfTensor', 'FloatTensor']):
elif (types == set(['HalfTensor', 'FloatTensor']) or
types == set(['BFloat16Tensor', 'FloatTensor'])):
cast_seq = utils.casted_args(maybe_float,
seq, {})
return orig_fn(cast_seq, *args, **kwargs)
......@@ -102,6 +104,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
if utils.type_string(arg0) == 'HalfTensor':
cast_fn = utils.maybe_half
elif utils.type_string(arg0) == 'BFloat16Tensor':
cast_fn = utils.maybe_bfloat16
elif utils.type_string(arg0) == 'FloatTensor':
cast_fn = utils.maybe_float
else:
......@@ -119,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs)
if 'HalfTensor' in types:
if 'HalfTensor' in types or 'BFloat16Tensor' in types:
if custom_err_msg:
raise NotImplementedError(custom_err_msg)
else:
raise NotImplementedError('Cannot call in-place function ' +
'{} with fp16 arguments.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
......@@ -137,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if utils.type_string(arg0) == 'HalfTensor':
if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
raise NotImplementedError('Cannot call in-place method ' +
'{} on fp16 Tensors.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
......@@ -219,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
return fwd_wrapper
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
def new_rnn_cast(fn, cast_fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
......@@ -232,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
......@@ -247,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
else:
params_idx = 3 # PackedSequence case
# cast_fn may have been wrapped by utils.verbosify (a functools.partial when
# verbose=True), so unwrap it before comparing against the raw cast functions.
raw_cast_fn = getattr(cast_fn, 'func', cast_fn)
if raw_cast_fn == utils.maybe_half:
dtype = torch.half
elif raw_cast_fn == utils.maybe_bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("Unsupported cast_fn passed to new_rnn_cast. Only maybe_half and maybe_bfloat16 are supported.")
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=torch.half)
dtype=dtype)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, verbose)
arg, fp16_weight_buf, fn, dtype, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
......
......@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(
BLOCK_SIZE,
......
......@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
......
......@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
......@@ -391,7 +391,7 @@ void multi_tensor_norm_out_cuda(
output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
if (norm_type == 0) {
DISPATCH_FLOAT_AND_HALF(
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
......@@ -405,7 +405,7 @@ void multi_tensor_norm_out_cuda(
max_chunks_per_tensor);)
}
else {
DISPATCH_FLOAT_AND_HALF(
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE,
......
......@@ -363,7 +363,7 @@ void multi_tensor_lamb_cuda(
// We now in-place modify grad to store update before compute its norm
// Generally this is not a issue since people modify grad in step() method all the time
// We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
......@@ -386,7 +386,7 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
......@@ -127,9 +127,9 @@ void multi_tensor_lamb_stage1_cuda(
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
......
......@@ -91,8 +91,8 @@ void multi_tensor_lamb_stage2_cuda(
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
......@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
tensor_lists[0][0].scalar_type(), 0, "novograd",
multi_tensor_apply<3>(
BLOCK_SIZE,
......
......@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
......
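With the dispatch macros extended to BFloat16, the fused multi-tensor kernels can consume bfloat16 tensor lists. A hedged sketch follows (assumes the amp_C extension is rebuilt from this branch; the tensors and scale factor are illustrative):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
src = [torch.randn(1024, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
dst = [torch.empty_like(t, dtype=torch.float32) for t in src]

# Scale-and-copy bfloat16 tensors into fp32 buffers; each list is dispatched
# on its own scalar type, so mixed bf16/fp32 lists are allowed.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)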