Commit c7fd532c authored by rohithkrn

basic enablement for O4 and O5 opt levels

parent 8124df13
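
For context, a hedged usage sketch of how the new opt levels added by this commit would be selected through the existing `amp.initialize` entry point. The `Linear` model, `SGD` optimizer, and tensor shapes below are placeholders, not part of this change; a CUDA-capable apex build is assumed.

```python
import torch
from apex import amp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# O4 mirrors O1, but patches Tensor Core-friendly ops to bfloat16 instead of fp16.
# Loss scaling is not required because bfloat16 has the same dynamic range as fp32
# (the opt level sets loss_scale=1).
model, optimizer = amp.initialize(model, optimizer, opt_level="O4")

loss = model(torch.randn(8, 1024, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:   # effectively a no-op scale at 1
    scaled_loss.backward()
optimizer.step()

# O5 mirrors O2: the model is cast to bfloat16 with FP32 batchnorms and FP32 master weights.
# model, optimizer = amp.initialize(model, optimizer, opt_level="O5")
```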
@@ -80,10 +80,10 @@ def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
             if param.is_floating_point():
-                if 'Half' in param.type():
+                if 'Half' in param.type() or 'BFloat16' in param.type():
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                                "When using amp.initialize, you do not need to call .half() on your model\n"
-                                "before passing it, no matter what optimization level you choose.".format(
+                                "When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
+                                "on your model before passing it, no matter what optimization level you choose.".format(
                                     name, param.type()))
                 elif not param.is_cuda:
                     warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
@@ -137,7 +137,7 @@ class O2StateDictHook(object):
     def __call__(self, module, state_dict, prefix, local_metadata):
         for key in state_dict:
             param = state_dict[key]
-            if 'Half' in param.type():
+            if 'Half' in param.type() or 'BFloat16' in param.type():
                 param = param.to(torch.float32)
             state_dict[key] = param
@@ -232,7 +232,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
-        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        handle = amp_init(loss_scale=properties.loss_scale,
+                          patch_type=properties.patch_torch_functions_type,
+                          verbose=(_amp_state.verbosity == 2))
         for optimizer in optimizers:
             # Disable Amp casting for the optimizer step, because it should only be
             # applied to FP32 master params anyway.
...
 import types
 from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
-from ._amp_state import maybe_print
+from ._amp_state import maybe_print, _amp_state
 import torch
 from ..optimizers import FusedSGD
@@ -13,7 +13,7 @@ class AmpOptimizerState(object):
 def _master_params_to_model_params(self):
     stash = self._amp_stash
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and _amp_state.opt_properties.opt_level not in {"O4", "O5"}:
         if len(stash.all_fp16_params) > 0:
             multi_tensor_applier(
                 stash.multi_tensor_scale,
@@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self):
         fp32_from_fp16_params_this_group = []
         for i, param in enumerate(param_group['params']):
             if param.requires_grad:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                     #             .format(param.size()))
                     fp16_params_this_group.append(param)
@@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self):
                     fp32_params_this_group.append(param)
                     param_group['params'][i] = param
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must be one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, or torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))
         stash.fp16_groups.append(fp16_params_this_group)
@@ -208,7 +208,7 @@ def lazy_init_no_master_weights(self):
     stash.all_fp32_params = []
     for i, param_group in enumerate(self.param_groups):
         for i, param in enumerate(param_group['params']):
-            if param.type() == 'torch.cuda.HalfTensor':
+            if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                 stash.all_fp16_params.append(param)
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)
@@ -337,7 +337,7 @@ def _process_optimizer(optimizer, properties):
            raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
     # TODO: Centralize exposure and import error checking for the C backend.
-    if multi_tensor_applier.available:
+    if multi_tensor_applier.available and properties.opt_level not in {"O4", "O5"}:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
         optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
@@ -435,7 +435,7 @@ def _process_optimizer(optimizer, properties):
         fp32_from_fp16_params_this_group = []
         for i, param in enumerate(new_group['params']):
             if param.requires_grad:
-                if param.type() == 'torch.cuda.HalfTensor':
+                if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                     fp16_params_this_group.append(param)
                     master_param = param.detach().clone().float()
                     master_param.requires_grad = True
@@ -445,8 +445,8 @@ def _process_optimizer(optimizer, properties):
                     fp32_params_this_group.append(param)
                     new_group['params'][i] = param
                 else:
-                    raise TypeError("Optimizer's parameters must be either "
-                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                    raise TypeError("Optimizer's parameters must be one of "
+                                    "torch.cuda.FloatTensor, torch.cuda.HalfTensor, or torch.cuda.BFloat16Tensor. "
                                     "Received {}".format(param.type()))
         stash.fp16_groups.append(fp16_params_this_group)
@@ -471,15 +471,15 @@ def _process_optimizer(optimizer, properties):
                 # param.grad = None
     else:
         for param in new_group['params']:
-            if param.type() == 'torch.cuda.HalfTensor':
+            if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
                 stash.all_fp16_params.append(param)
                 stash.all_fp16_grad_stash.append(None)
             elif param.type() == 'torch.cuda.FloatTensor':
                 stash.all_fp32_params.append(param)
                 stash.all_fp32_grad_stash.append(None)
            else:
-                raise TypeError("Optimizer's parameters must be either "
-                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                raise TypeError("Optimizer's parameters must be one of "
+                                "torch.cuda.FloatTensor, torch.cuda.HalfTensor, or torch.cuda.BFloat16Tensor. "
                                "Received {}".format(param.type()))
     old_add_param_group(new_group)
...
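
For orientation, a hedged sketch (not part of the patch) of the parameter bucketing that these hunks extend to cover `torch.cuda.BFloat16Tensor`. The `bucket_params` helper and `low_precision_types` name are illustrative only.

```python
import torch

# Parameters in these low-precision types get FP32 master copies when master_weights is enabled.
low_precision_types = {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}

def bucket_params(params):
    low_prec, fp32 = [], []
    for p in params:
        if p.type() in low_precision_types:
            low_prec.append(p)          # fp16/bfloat16 params: shadowed by FP32 master weights
        elif p.type() == 'torch.cuda.FloatTensor':
            fp32.append(p)              # already FP32: used directly by the optimizer
        else:
            raise TypeError("Unsupported parameter type {}".format(p.type()))
    return low_prec, fp32
```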
@@ -9,7 +9,6 @@ import itertools
 import torch
 _DECORATOR_HANDLE = None
 _USER_CAST_REGISTRY = set()
 _USER_PROMOTE_REGISTRY = set()
@@ -65,7 +64,7 @@ def register_promote_function(module, name):
 # Top-level function to insert _all_ the hooks.
-def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
     global _DECORATOR_HANDLE
     if not enabled:
@@ -87,27 +86,41 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
         wrap.promote(mod, fn, handle, verbose)
     _USER_PROMOTE_REGISTRY.clear()
+    # Conditionally choose between the fp16 and bfloat16 function lists to patch and cache.
+    if patch_type == torch.float16:
+        low_prec_funcs = 'FP16_FUNCS'
+        maybe_low_prec = utils.maybe_half
+        low_prec_tensor = torch.cuda.HalfTensor
+    elif patch_type == torch.bfloat16:
+        low_prec_funcs = 'BFLOAT16_FUNCS'
+        maybe_low_prec = utils.maybe_bfloat16
+        low_prec_tensor = torch.cuda.BFloat16Tensor
+    else:
+        raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize. " +
+                           "Supported types are: torch.float16 and torch.bfloat16.")
     # 1) Force-{fp16, fp32} on white- / black-list functions
     override_modules = [functional_overrides,
                         torch_overrides,
                         tensor_overrides]
-    cast_table = [('FP16_FUNCS', utils.maybe_half),
+    cast_table = [(low_prec_funcs, maybe_low_prec),
                   ('FP32_FUNCS', utils.maybe_float)]
     for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                           cast_table):
         for fn in getattr(module, list_name):
-            try_caching = (cast_fn == utils.maybe_half)
+            try_caching = (cast_fn == maybe_low_prec)
             wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                              try_caching, verbose)
     # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
     #      methods on FloatTensor, since they're distinct types.
     if compat.tensor_is_float_tensor():
-        for fn in tensor_overrides.FP16_FUNCS:
-            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
+        for fn in getattr(tensor_overrides, low_prec_funcs):
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, maybe_low_prec,
                              handle, try_caching=True, verbose=verbose)
         for fn in tensor_overrides.FP32_FUNCS:
-            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
+            wrap.cached_cast(low_prec_tensor, fn, utils.maybe_float,
                              handle, try_caching=False, verbose=verbose)
     # 2) Enable type-promotion on multi-arg functions and methods.
@@ -123,7 +136,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
     if compat.tensor_is_float_tensor():
         for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
-                                                               torch.cuda.HalfTensor],
+                                                               low_prec_tensor],
                                                               promote_table):
             for fn in getattr(tensor_overrides, list_name):
                 promote_fn(cls, fn, handle, verbose)
@@ -141,11 +154,11 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
     # 4) For other in-place methods, match the type of self tensor
     for fn in utils.as_inplace(itertools.chain(
-            tensor_overrides.FP16_FUNCS,
+            getattr(tensor_overrides, low_prec_funcs),
             tensor_overrides.CASTS)):
         wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
         if compat.tensor_is_float_tensor():
-            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(low_prec_tensor, fn, handle, verbose)
             wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
     # 5) RNNs + RNN cells are whitelisted specially
...
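
To illustrate the intended effect once `init(patch_type=torch.bfloat16)` has patched torch functions (e.g. after `amp.initialize(..., opt_level="O4")`), a hedged sketch. The tensors and shapes are placeholders, and the printed dtype reflects the expected behavior of the cached cast rather than a verified output.

```python
import torch
# Assumes amp.initialize(model, optimizer, opt_level="O4") has already run,
# so whitelisted ops cast their inputs to bfloat16 and blacklisted ops to float32.

a = torch.randn(8, 8, device="cuda")    # float32 input
b = torch.randn(8, 8, device="cuda")

out = torch.mm(a, b)                    # 'mm' is in BFLOAT16_FUNCS -> inputs cast to bfloat16
print(out.dtype)                        # expected: torch.bfloat16 while the patch is active
```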
@@ -28,7 +28,8 @@ def is_floating_point(x):
         torch_type = x.type()
         return torch_type.endswith('FloatTensor') or \
             torch_type.endswith('HalfTensor') or \
-            torch_type.endswith('DoubleTensor')
+            torch_type.endswith('DoubleTensor') or \
+            torch_type.endswith('BFloat16Tensor')
     except AttributeError:
         return False
...
@@ -16,6 +16,7 @@ class Properties(object):
             "opt_level" : None,
             "cast_model_type" : None,
             "patch_torch_functions" : False,
+            "patch_torch_functions_type" : None,
             "keep_batchnorm_fp32" : None,
             "master_weights" : None,
             "loss_scale" : 1.0,
@@ -53,7 +54,7 @@ class Properties(object):
         if name in self.options:
             # print("setting {} {}".format(name, value))
             if name == "cast_model_type":
-                if self.opt_level == "O1" and value is not None:
+                if self.opt_level in {"O1", "O4"} and value is not None:
                     if value is not False:
                         if value is not torch.float32:
                             warn_or_err("O1 inserts casts around Torch functions rather than "
@@ -63,13 +64,25 @@ class Properties(object):
                                         "cast_model_type was {}".format(value))
                 self.options[name] = value
             elif name == "patch_torch_functions":
-                if self.opt_level != "O1" and value:
+                if self.opt_level not in {"O1", "O4"} and value:
                     warn_or_err("Currently, patch_torch_functions=True should only be set by "
-                                "selecting opt_level='O1'.")
+                                "selecting opt_level='O1' or 'O4'.")
                 self.options[name] = value
+            elif name == "patch_torch_functions_type":
+                if self.opt_level not in {"O1", "O4"} and value is not None:
+                    warn_or_err("Currently, patch_torch_functions_type should only be set by "
+                                "selecting opt_level='O1' or 'O4'.")
+                elif self.opt_level == "O1" and value != torch.float16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
+                                "for opt_level='O1'.")
+                elif self.opt_level == "O4" and value != torch.bfloat16:
+                    warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
+                                "for opt_level='O4'.")
+                else:
+                    self.options[name] = value
             elif name == "keep_batchnorm_fp32":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
                                 "to run in FP32, so keep_batchnorm_fp32 should be None." +
                                 " keep_batchnorm_fp32 was {}".format(value))
                 if value == "False":
@@ -82,9 +95,9 @@ class Properties(object):
                     "or None, found keep_batchnorm_fp32={}".format(value)
                 self.options[name] = value
             elif name == "master_weights":
-                if self.opt_level == "O1" and value is not None:
-                    warn_or_err("It doesn't make sense to use master_weights with O1. "
-                                "With O1, your model weights themselves should be FP32.")
+                if self.opt_level in {"O1", "O4"} and value is not None:
+                    warn_or_err("It doesn't make sense to use master_weights with O1 or O4. "
+                                "With O1 and O4, your model weights themselves should be FP32.")
                 self.options[name] = value
             elif name == "loss_scale":
                 if value == "dynamic":
@@ -113,6 +126,7 @@ class O3:
         properties.opt_level = "O3"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = False
         properties.master_weights = False
         properties.loss_scale = 1.0
@@ -136,6 +150,7 @@ class O2:
         properties.opt_level = "O2"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = True
         properties.master_weights = True
         properties.loss_scale = "dynamic"
@@ -158,6 +173,7 @@ class O1:
         properties.opt_level = "O1"
         properties.cast_model_type = None
         properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.float16
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = None
         properties.loss_scale = "dynamic"
@@ -177,6 +193,7 @@ class O0:
         properties.opt_level = "O0"
         properties.cast_model_type = torch.float32
         properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
         properties.keep_batchnorm_fp32 = None
         properties.master_weights = False
         properties.loss_scale = 1.0
@@ -184,11 +201,54 @@ class O0:
         # properties.enable_ddp_interop = False
         return properties # modified in place so this isn't really necessary
+
+class O4:
+    brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n"
+    more = "The type of your model's weights is not altered. However, internally,\n"\
+           "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
+           "while operations that might benefit from the additional stability of FP32 are patched\n"\
+           "to cast their inputs to fp32.\n"\
+           "Loss scaling is not required in O4 mode since bfloat16 has the same dynamic range as fp32."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O4"
+        properties.cast_model_type = None
+        properties.patch_torch_functions = True
+        properties.patch_torch_functions_type = torch.bfloat16
+        properties.keep_batchnorm_fp32 = None
+        properties.master_weights = None
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
+class O5:
+    brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
+    more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
+           "to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\
+           "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
+           "your data pipeline.\n"\
+           "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
+           "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
+           "Master weights can also improve convergence and stability."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O5"
+        properties.cast_model_type = torch.bfloat16
+        properties.patch_torch_functions = False
+        properties.patch_torch_functions_type = None
+        properties.keep_batchnorm_fp32 = True
+        properties.master_weights = True
+        properties.loss_scale = 1
+        return properties # modified in place so this isn't really necessary
+
+
 opt_levels = {"O3": O3(),
               "O2": O2(),
               "O1": O1(),
-              "O0": O0()}
+              "O0": O0(),
+              "O4": O4(),
+              "O5": O5()}
 # allow user to directly pass Properties struct as well?
@@ -199,6 +259,7 @@ def initialize(
     opt_level="O1",
     cast_model_type=None,
     patch_torch_functions=None,
+    patch_torch_functions_type=None,
     keep_batchnorm_fp32=None,
     master_weights=None,
     loss_scale=None,
@@ -235,10 +296,11 @@ def initialize(
         enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present.
         opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
-            "O0", "O1", "O2", and "O3", explained in detail above.
+            "O0", "O1", "O2", "O3", "O4", and "O5", explained in detail above.
         cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
            above.
         patch_torch_functions (bool, optional, default=None): Optional property override.
+        patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override.
         keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
            passed as a string, must be the string "True" or "False".
         master_weights (bool, optional, default=None): Optional property override.
@@ -321,7 +383,7 @@ def initialize(
     if opt_level not in opt_levels:
         raise RuntimeError(
             "Unexpected optimization level {}. ".format(opt_level) +
-            "Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
+            "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
             "not the number zero.")
     else:
         _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
@@ -344,6 +406,8 @@ def initialize(
         _amp_state.opt_properties.cast_model_type = cast_model_type
     if patch_torch_functions is not None:
         _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if patch_torch_functions_type is not None:
+        _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
     if keep_batchnorm_fp32 is not None:
         _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
     if master_weights is not None:
...
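
A hedged sketch of what the O5 properties imply for a model after `amp.initialize`. The `Sequential`/`BatchNorm1d` model and shapes are placeholders, and the printed dtypes reflect the documented intent of the opt level (bfloat16 weights, FP32 batchnorms, FP32 master weights in the optimizer) rather than verified output.

```python
import torch
from apex import amp

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.BatchNorm1d(32)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = amp.initialize(model, optimizer, opt_level="O5")

# Linear weights are cast to bfloat16; batchnorm weights stay FP32 (keep_batchnorm_fp32=True);
# the optimizer steps on FP32 master copies and copies the results back into the bfloat16 weights.
print(model[0].weight.dtype)   # expected: torch.bfloat16
print(model[1].weight.dtype)   # expected: torch.float32
```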
@@ -26,6 +26,17 @@ FP16_FUNCS = [
     'linear',
 ]
+BFLOAT16_FUNCS = [
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc', # Undocumented / maybe new?
+    'linear',
+]
+
 FP32_FUNCS = [
     # Interpolation/Upsampling TODO: Remove for 1.2
...
@@ -15,6 +15,10 @@ FP16_FUNCS = [
     '__matmul__',
 ]
+BFLOAT16_FUNCS = [
+    '__matmul__',
+]
+
 FP32_FUNCS = [
     '__ipow__',
     '__pow__',
@@ -56,7 +60,7 @@ SEQUENCE_CASTS = []
 # between `torch` and `torch.Tensor` (and check with `hasattr`,
 # because a few random ones aren't defined on Tensor)
 _self_mod = importlib.import_module(__name__)
-for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
+for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
     lst = getattr(_self_mod, attrname)
     for fn in getattr(torch_overrides, attrname):
         if hasattr(MODULE, fn):
...
@@ -26,6 +26,27 @@ FP16_FUNCS = [
     'mv',
 ]
+BFLOAT16_FUNCS = [
+    # Low level functions wrapped by torch.nn layers.
+    # The wrapper layers contain the weights which are then passed in as a parameter
+    # to these functions.
+    'conv1d',
+    'conv2d',
+    'conv3d',
+    'conv_transpose1d',
+    'conv_transpose2d',
+    'conv_transpose3d',
+    'conv_tbc',
+
+    # BLAS
+    'addmm',
+    'addmv',
+    'addr',
+    'matmul',
+    'mm',
+    'mv',
+]
+
 FP32_FUNCS = [
     # Pointwise
     'acos',
...
@@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False):
             print('Float->Half ({})'.format(name))
         return x.half()
+
+def maybe_bfloat16(x, name='', verbose=False):
+    if is_nested(x):
+        return type(x)([maybe_bfloat16(y) for y in x])
+
+    if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
+        return x
+    else:
+        if verbose:
+            print('Float->BFloat16 ({})'.format(name))
+        return x.bfloat16()
+
+
 def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
...
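
A hedged illustration of how the new `maybe_bfloat16` helper behaves, mirroring `maybe_half`. Tensor names are placeholders and a CUDA device is assumed.

```python
import torch
from apex.amp import utils

x = torch.randn(4, 4, device="cuda")            # torch.cuda.FloatTensor
y = utils.maybe_bfloat16(x, name="example")     # CUDA float -> bfloat16 cast
z = utils.maybe_bfloat16(torch.randn(4, 4))     # non-CUDA tensor is returned unchanged
nested = utils.maybe_bfloat16([x, x])           # lists/tuples are cast element-wise
```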
@@ -102,6 +102,8 @@ def promote_match_arg0(mod, fn, handle, verbose=False):
         if utils.type_string(arg0) == 'HalfTensor':
             cast_fn = utils.maybe_half
+        elif utils.type_string(arg0) == 'BFloat16Tensor':
+            cast_fn = utils.maybe_bfloat16
         elif utils.type_string(arg0) == 'FloatTensor':
             cast_fn = utils.maybe_float
         else:
...