Commit 6f7a8b39 authored by lcskrishna

Merge remote-tracking branch 'rocm_upstream/master' into ifu_07272020

parents 459de22d 9c80f6d3
sudo docker build . --rm -t apex
sudo docker run -it -v $HOME:/data --rm --privileged --device=/dev/dri --device=/dev/kfd --network host --group-add video apex
ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu
FROM ${FROM_IMAGE}
RUN \
git clone --recursive https://github.com/ROCmSoftwarePlatform/apex.git && \
cd apex && \
python3.6 setup.py install --cpp_ext --cuda_ext
......@@ -115,8 +115,27 @@ It's often convenient to use Apex in Docker containers. Compatible options incl
See the [Docker example folder](https://github.com/NVIDIA/apex/tree/master/examples/docker) for details.
## On ROCm:
* Python 3.6
* PyTorch 1.5 or newer; the HIP extensions require 1.5 or newer.
* We recommend following the instructions from [ROCm-Pytorch](https://github.com/ROCmSoftwarePlatform/pytorch) to install PyTorch on ROCm. You can verify the installed version as shown below.
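A quick way to check the installed PyTorch and HIP versions (a minimal sanity check; `torch.version.hip` is only populated on ROCm builds):
```
python3.6 -c "import torch; print(torch.__version__, torch.version.hip)"
```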
# Quick Start
### ROCm
Apex on ROCm supports both a Python-only build and an extension build.
Note: the extension build requires PyTorch >= 1.5.
### To install the Python-only build, use the following command in the apex folder:
```
python3.6 setup.py install
```
### To install with the extensions enabled, use the following command in the apex folder:
```
python3.6 setup.py install --cpp_ext --cuda_ext
```
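As a quick sanity check after either build, confirm that Apex and its amp module import cleanly (a minimal sketch; it only assumes the install above completed against the ROCm-enabled PyTorch):
```
python3.6 -c "import apex; from apex import amp; print(amp.__name__)"
```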
### Linux
For performance and full functionality, we recommend installing Apex with
......
......@@ -18,3 +18,6 @@ from . import fp16_utils
from . import optimizers
from . import normalization
from . import pyprof
# Common utilities to run tests on ROCm.
from . import testing
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
register_half_function, register_bfloat16_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
......@@ -80,10 +80,10 @@ def check_params_fp32(models):
for model in models:
for name, param in model.named_parameters():
if param.is_floating_point():
if 'Half' in param.type():
if 'Half' in param.type() or 'BFloat16' in param.type():
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
"When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
"on your model before passing it, no matter what optimization level you choose.".format(
name, param.type()))
elif not param.is_cuda:
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
......@@ -137,7 +137,7 @@ class O2StateDictHook(object):
def __call__(self, module, state_dict, prefix, local_metadata):
for key in state_dict:
param = state_dict[key]
if 'Half' in param.type():
if 'Half' in param.type() or 'BFloat16' in param.type():
param = param.to(torch.float32)
state_dict[key] = param
......@@ -189,7 +189,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
for model in models:
# Patch the forward method to cast incoming data to the correct type, and
# outgoing data to float32, so "the user never needs to call .half()."
# outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
......@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
handle = amp_init(loss_scale=properties.loss_scale,
patch_type=properties.patch_torch_functions_type,
verbose=(_amp_state.verbosity == 2))
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
......
import types
from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
from ._amp_state import maybe_print, _amp_state
import torch
from ..optimizers import FusedSGD
......@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
# maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
# .format(param.size()))
fp16_params_this_group.append(param)
......@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
......@@ -208,13 +208,13 @@ def lazy_init_no_master_weights(self):
stash.all_fp32_params = []
for i, param_group in enumerate(self.param_groups):
for i, param in enumerate(param_group['params']):
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
......@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
fp32_from_fp16_params_this_group = []
for i, param in enumerate(new_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
......@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
fp32_params_this_group.append(param)
new_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
......@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
# param.grad = None
else:
for param in new_group['params']:
if param.type() == 'torch.cuda.HalfTensor':
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
stash.all_fp16_grad_stash.append(None)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
stash.all_fp32_grad_stash.append(None)
else:
raise TypeError("Optimizer's parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
raise TypeError("Optimizer's parameters must one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
old_add_param_group(new_group)
......
......@@ -9,7 +9,6 @@ import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
......@@ -31,6 +30,9 @@ def half_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def bfloat16_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)
def float_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
......@@ -49,6 +51,11 @@ def register_half_function(module, name):
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
def register_bfloat16_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))
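# Usage sketch (hypothetical module and function names): registering a custom op so
# that amp casts its inputs to bfloat16 when patching is active. As with
# register_half_function, this should be called before amp.init()/amp.initialize(),
# because init() consumes and clears _USER_CAST_REGISTRY.
#
#   import my_extension                                        # hypothetical module
#   amp.register_bfloat16_function(my_extension, 'fused_op')   # hypothetical function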
def register_float_function(module, name):
if not hasattr(module, name):
......@@ -65,7 +72,7 @@ def register_promote_function(module, name):
# Top-level function to insert _all_ the hooks.
def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
global _DECORATOR_HANDLE
if not enabled:
......@@ -87,16 +94,30 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
wrap.promote(mod, fn, handle, verbose)
_USER_PROMOTE_REGISTRY.clear()
# Conditionally choose between the fp16 and bfloat16 function lists to patch and cache.
if patch_type == torch.float16:
low_prec_funcs = 'FP16_FUNCS'
maybe_low_prec = utils.maybe_half
low_prec_tensor = torch.cuda.HalfTensor
elif patch_type == torch.bfloat16:
low_prec_funcs = 'BFLOAT16_FUNCS'
maybe_low_prec = utils.maybe_bfloat16
low_prec_tensor = torch.cuda.BFloat16Tensor
else:
raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." +
"Supported types are: torch.float16 and torch.bfloat16.")
# 1) Force-{fp16, fp32} on white- / black-list functions
override_modules = [functional_overrides,
torch_overrides,
tensor_overrides]
cast_table = [('FP16_FUNCS', utils.maybe_half),
cast_table = [(low_prec_funcs, maybe_low_prec),
('FP32_FUNCS', utils.maybe_float)]
for module, (list_name, cast_fn) in itertools.product(override_modules,
cast_table):
for fn in getattr(module, list_name):
try_caching = (cast_fn == utils.maybe_half)
try_caching = (cast_fn == maybe_low_prec)
wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
try_caching, verbose)
......@@ -128,12 +149,12 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
for fn in getattr(tensor_overrides, list_name):
promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16.
# 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
# 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
......@@ -141,7 +162,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
# 4) For other in-place methods, match the type of self tensor
for fn in utils.as_inplace(itertools.chain(
tensor_overrides.FP16_FUNCS,
getattr(tensor_overrides, low_prec_funcs),
tensor_overrides.CASTS)):
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
......@@ -156,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
# Wrap all the rnns
for x in rnn_compat.RNN_NAMES:
wrap.new_rnn_cast(x.upper(), handle, verbose)
wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)
# Wrap all the RNN cells
rnn_compat.whitelist_rnn_cells(handle, verbose)
rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)
# 6) Place error+print message on banned functions.
# Or, if allow_banned, then cast to FP32.
......
......@@ -28,7 +28,8 @@ def is_floating_point(x):
torch_type = x.type()
return torch_type.endswith('FloatTensor') or \
torch_type.endswith('HalfTensor') or \
torch_type.endswith('DoubleTensor')
torch_type.endswith('DoubleTensor') or \
torch_type.endswith('BFloat16Tensor')
except AttributeError:
return False
......
......@@ -16,6 +16,10 @@ class Properties(object):
"opt_level" : None,
"cast_model_type" : None,
"patch_torch_functions" : False,
# TODO: patch_torch_functions_type could probably be unified with
# patch_torch_functions. Currently introducing a new attribute
# to be on the safer side and not break stuff.
"patch_torch_functions_type" : None,
"keep_batchnorm_fp32" : None,
"master_weights" : None,
"loss_scale" : 1.0,
......@@ -53,7 +57,7 @@ class Properties(object):
if name in self.options:
# print("setting {} {}".format(name, value))
if name == "cast_model_type":
if self.opt_level == "O1" and value is not None:
if self.opt_level in {"O1", "O4"} and value is not None:
if value is not False:
if value is not torch.float32:
warn_or_err("O1 inserts casts around Torch functions rather than "
......@@ -63,13 +67,25 @@ class Properties(object):
"cast_model_type was {}".format(value))
self.options[name] = value
elif name == "patch_torch_functions":
if self.opt_level != "O1" and value:
if self.opt_level not in {"O1", "O4"} and value:
warn_or_err("Currently, patch_torch_functions=True should only be set by "
"selecting opt_level='O1'.")
"selecting opt_level='O1' or 'O4'.")
self.options[name] = value
elif name == "patch_torch_functions_type":
if self.opt_level not in {"O1", "O4"} and value is not None:
warn_or_err("Currently, patch_torch_functions_type should only be set by "
"selecting opt_level='O1' or 'O4'.")
elif self.opt_level == "O1" and value != torch.float16:
warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
"for opt_level='O1.")
elif self.opt_level == "O4" and value != torch.bfloat16:
warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
"for opt_level='O4.")
else:
self.options[name] = value
elif name == "keep_batchnorm_fp32":
if self.opt_level == "O1" and value is not None:
warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
"to run in FP32, so keep_batchnorm_fp32 should be None." +
" keep_batchnorm_fp32 was {}".format(value))
if value == "False":
......@@ -82,9 +98,9 @@ class Properties(object):
"or None, found keep_batchnorm_fp32={}".format(value)
self.options[name] = value
elif name == "master_weights":
if self.opt_level == "O1" and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1. "
"With O1, your model weights themselves should be FP32.")
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1 and O4 . "
"With O1 and O4, your model weights themselves should be FP32.")
self.options[name] = value
elif name == "loss_scale":
if value == "dynamic":
......@@ -113,6 +129,7 @@ class O3:
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
......@@ -136,6 +153,7 @@ class O2:
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = "dynamic"
......@@ -158,6 +176,7 @@ class O1:
properties.opt_level = "O1"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.float16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = "dynamic"
......@@ -177,6 +196,7 @@ class O0:
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = None
properties.master_weights = False
properties.loss_scale = 1.0
......@@ -184,11 +204,54 @@ class O0:
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O4:
brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O4"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.bfloat16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = 1
return properties # modified in place so this isn't really necessary
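# Usage sketch (hypothetical model/optimizer objects): O4 is the bfloat16 analogue
# of O1 -- Torch functions are patched so whitelisted ops run in bfloat16.
#
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O4")
#   with amp.scale_loss(loss, optimizer) as scaled_loss:
#       scaled_loss.backward()
#
# With the O4 default of loss_scale=1, scale_loss applies no actual scaling,
# since bfloat16 has the same exponent range as fp32.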
class O5:
brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
"to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
"your data pipeline.\n"\
"O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O5"
properties.cast_model_type = torch.bfloat16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = 1
return properties # modified in place so this isn't really necessary
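# Usage sketch (hypothetical objects): O5 mirrors O2, but the model is cast to
# bfloat16 while FP32 master weights are kept inside the patched optimizer.
#
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O5")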
opt_levels = {"O3": O3(),
"O2": O2(),
"O1": O1(),
"O0": O0()}
"O0": O0(),
"O4": O4(),
"O5": O5()}
# allow user to directly pass Properties struct as well?
......@@ -199,6 +262,7 @@ def initialize(
opt_level="O1",
cast_model_type=None,
patch_torch_functions=None,
patch_torch_functions_type=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None,
......@@ -235,10 +299,11 @@ def initialize(
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
should run as if Amp were not present.
opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
"O0", "O1", "O2", and "O3", explained in detail above.
"O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above.
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
above.
patch_torch_functions (bool, optional, default=None): Optional property override.
patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override.
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
passed as a string, must be the string "True" or "False".
master_weights (bool, optional, default=None): Optional property override.
......@@ -321,14 +386,14 @@ def initialize(
if opt_level not in opt_levels:
raise RuntimeError(
"Unexpected optimization level {}. ".format(opt_level) +
"Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
"Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
"not the number zero.")
else:
_amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
maybe_print("Defaults for this optimization level are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
_amp_state.min_loss_scale = min_loss_scale
_amp_state.max_loss_scale = max_loss_scale
......@@ -344,6 +409,8 @@ def initialize(
_amp_state.opt_properties.cast_model_type = cast_model_type
if patch_torch_functions is not None:
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
if patch_torch_functions_type is not None:
_amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
if keep_batchnorm_fp32 is not None:
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
if master_weights is not None:
......@@ -353,7 +420,7 @@ def initialize(
maybe_print("After processing overrides, optimization options are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:22} : {}".format(k, v), True)
maybe_print("{:26} : {}".format(k, v), True)
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
......
......@@ -26,6 +26,17 @@ FP16_FUNCS = [
'linear',
]
BFLOAT16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
FP32_FUNCS = [
# Interpolation/Upsampling TODO: Remove for 1.2
......
......@@ -15,6 +15,10 @@ FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__',
])
BFLOAT16_FUNCS = [
'__matmul__',
]
FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__',
'__pow__',
......@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
lst = getattr(_self_mod, attrname)
for fn in getattr(torch_overrides, attrname):
if hasattr(MODULE, fn):
......
......@@ -26,6 +26,27 @@ FP16_FUNCS = [
'mv',
]
BFLOAT16_FUNCS = [
# Low-level functions wrapped by torch.nn layers.
# The wrapper layers hold the weights, which are then passed in as parameters
# to these functions.
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
FP32_FUNCS = [
# Pointwise
'acos',
......
......@@ -28,7 +28,7 @@ def has_old_rnns():
except:
return False
def whitelist_rnn_cells(handle, verbose):
def whitelist_rnn_cells(cast_fn, handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
......@@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose):
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
wrap.cached_cast(mod, fn, cast_fn, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
......
......@@ -6,12 +6,18 @@ from itertools import product
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
cpu_sum = float(model_grad.float().sum())
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
if master_grad is not model_grad: # copy_ probably internally short-circuits this
master_grad.copy_(model_grad)
if model_grad.is_sparse:
master_grad.copy_(model_grad.to_dense())
else:
master_grad.copy_(model_grad)
if scale != 1.0:
master_grad.mul_(scale)
return False
......@@ -19,7 +25,10 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
cpu_sum = float(model_grad.float().sum())
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
......
......@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
print('Float->Half ({})'.format(name))
return x.half()
def maybe_bfloat16(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_bfloat16(y) for y in x])
if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
return x
else:
if verbose:
print('Float->BFloat16 ({})'.format(name))
return x.bfloat16()
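# Minimal illustration (assumes a CUDA/ROCm device is available): tensors that are
# not on the GPU, or already bfloat16, pass through unchanged; CUDA float tensors are cast.
#
#   x = torch.randn(4, device='cuda')
#   maybe_bfloat16(x).dtype          # torch.bfloat16
#   maybe_bfloat16(x.cpu()).dtype    # torch.float32 (not a CUDA tensor, returned as-is)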
def maybe_float(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_float(y) for y in x])
......@@ -189,22 +200,28 @@ def synthesize_flattened_rnn_weights(fp32_weights,
fp16_weights.append(fp16_layer_weights)
return fp16_weights
def _str_from_dtype(dtype=torch.float16):
type_to_str = {torch.float16 : 'Half',
torch.bfloat16 : 'BFloat16'}
return type_to_str[dtype]
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
dtype=torch.float16,
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().half()
w_fp16 = w_fp32.new().to(dtype=dtype)
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
......@@ -51,7 +51,8 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor'])
or types == set(['BFloat16Tensor', 'FloatTensor'])):
new_args = utils.casted_args(cast_fn,
args,
kwargs)
......@@ -79,7 +80,8 @@ def sequence_promote(mod, fn, handle, verbose=False):
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
elif types == set(['HalfTensor', 'FloatTensor']):
elif (types == set(['HalfTensor', 'FloatTensor']) or
types == set(['BFloat16Tensor', 'FloatTensor'])):
cast_seq = utils.casted_args(maybe_float,
seq, {})
return orig_fn(cast_seq, *args, **kwargs)
......@@ -102,6 +104,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
if utils.type_string(arg0) == 'HalfTensor':
cast_fn = utils.maybe_half
elif utils.type_string(arg0) == 'BFloat16Tensor':
cast_fn = utils.maybe_bfloat16
elif utils.type_string(arg0) == 'FloatTensor':
cast_fn = utils.maybe_float
else:
......@@ -119,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs)
if 'HalfTensor' in types:
if 'HalfTensor' in types or 'BFloat16Tensor' in types:
if custom_err_msg:
raise NotImplementedError(custom_err_msg)
else:
raise NotImplementedError('Cannot call in-place function ' +
'{} with fp16 arguments.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
......@@ -137,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False):
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if utils.type_string(arg0) == 'HalfTensor':
if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
raise NotImplementedError('Cannot call in-place method ' +
'{} on fp16 Tensors.'.format(fn))
'{} with fp16 or bfloat16 args.'.format(fn))
else:
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
......@@ -219,7 +223,7 @@ def rnn_cast(backend, fn, handle, verbose=False):
return fwd_wrapper
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
def new_rnn_cast(fn, cast_fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
......@@ -232,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False):
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
......@@ -247,14 +251,20 @@ def new_rnn_cast(fn, handle, verbose=False):
else:
params_idx = 3 # PackedSequence case
if cast_fn == utils.maybe_half:
dtype = torch.half
elif cast_fn == utils.maybe_bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=torch.half)
dtype=dtype)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, verbose)
arg, fp16_weight_buf, fn, dtype, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
......
......@@ -76,7 +76,7 @@ struct AdamFunctor
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
TensorListMetadata<DEPTH>* tl,
const float b1,
const float b2,
const float eps,
......@@ -85,21 +85,21 @@ struct AdamFunctor
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int tensor_loc = tl->block_to_tensor[blockIdx.x];
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
T* p = (T *)tl->addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
T* m = (T *)tl->addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
T* v = (T *)tl->addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
GRAD_T* g = (GRAD_T *)tl->addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy = (GRAD_T *)tl->addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
......@@ -736,17 +736,17 @@ struct MaybeCastFunctor
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl)
TensorListMetadata<DEPTH>* tl)
{
if (overflow_flag && *overflow_flag != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int tensor_loc = tl->block_to_tensor[blockIdx.x];
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
FROM_T* p_in = (FROM_T *)tl->addresses[0][tensor_loc];
p_in += chunk_idx*chunk_size;
TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
TO_T* p_out = (TO_T *)tl->addresses[1][tensor_loc];
p_out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
......
......@@ -32,7 +32,7 @@ struct LAMBStage1Functor
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
TensorListMetadata<4>* tl,
const float beta1,
const float beta2,
const float beta3,
......@@ -48,22 +48,22 @@ struct LAMBStage1Functor
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int tensor_loc = tl->block_to_tensor[blockIdx.x];
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
T* g = (T*)tl->addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
T* p = (T*)tl->addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
T* m = (T*)tl->addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
T* v = (T*)tl->addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
......@@ -147,7 +147,7 @@ struct LAMBStage2Functor
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
TensorListMetadata<2>* tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate,
......@@ -157,10 +157,10 @@ struct LAMBStage2Functor
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int tensor_loc = tl->block_to_tensor[blockIdx.x];
int tensor_num = tl->start_tensor_this_launch + tensor_loc;
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
MATH_T ratio = learning_rate;
// apply adaptive learning rate to parameters with non-zero weight decay
......@@ -171,10 +171,10 @@ struct LAMBStage2Functor
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc];
T* update = (T*)tl->addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
T* p = (T*)tl->addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
......