Commit 248d7b10 authored by Michael Carilli

Updating error checking on property overrides

parent d44ce75a
@@ -13,8 +13,10 @@ def warn_or_err(msg):
     if _amp_state.hard_override:
         print("Warning: " + msg)
     else:
-        raise RuntimeError(msg + " If you're sure you know what you're doing, supply " +
-                           "hard_override=True to amp.initialize.")
+        raise RuntimeError(msg)
+        # I'm not sure if allowing hard_override is a good idea.
+        # + " If you're sure you know what you're doing, supply " +
+        # "hard_override=True to amp.initialize.")

 # def iter_params(param_groups):
 #     for group in param_groups:
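For reference, the helper's behavior after this change can be reproduced with the minimal standalone sketch below (not part of the commit; the stub `_AmpState` class stands in for the library's internal state object). Since `hard_override` defaults to False, a flagged property override is now always a hard error:

# Self-contained sketch (not the library code): mimics warn_or_err with the
# default hard_override=False, under which any flagged override is fatal.
class _AmpState:
    hard_override = False  # True would downgrade errors to printed warnings

_amp_state = _AmpState()

def warn_or_err(msg):
    if _amp_state.hard_override:
        print("Warning: " + msg)
    else:
        raise RuntimeError(msg)

warn_or_err("cast_model_type was torch.float16")  # raises RuntimeError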
@@ -163,10 +163,12 @@ def _initialize(models, optimizers, properties):
                 optimizers[i] = wrap_fused_adam(optimizer, properties)
             if properties.loss_scale == "dynamic":
                 optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       dynamic_loss_scale=True)
+                                                       dynamic_loss_scale=True,
+                                                       verbose=False)
             else:
                 optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       static_loss_scale=properties.loss_scale)
+                                                       static_loss_scale=properties.loss_scale,
+                                                       verbose=False)
     else:
         for optimizer in optimizers:
             optimizer.loss_scaler = LossScaler(properties.loss_scale)
@@ -16,7 +16,7 @@ class Properties(object):
             "cast_model_type" : None,
             "patch_torch_functions" : False,
             "keep_batchnorm_fp32" : None,
-            "master_weights" : False,
+            "master_weights" : None,
             "loss_scale" : 1.0,
             # Reserved for future functionality
             # "fused_optimizer" : False,
@@ -51,12 +51,25 @@ class Properties(object):
         if "options" in self.__dict__:
             if name in self.options:
                 # print("setting {} {}".format(name, value))
-                if name == "loss_scale":
-                    if value == "dynamic":
+                if name == "cast_model_type":
+                    if self.opt_level == "O1" and value is not None:
+                        if value is not torch.float32:
+                            warn_or_err("O1 inserts casts around Torch functions rather than "
+                                        "model weights, so with O1, the model weights themselves "
+                                        "should remain FP32. If you wish to cast the model to a "
+                                        "different type, use opt_level='O2' or 'O3'. " +
+                                        "cast_model_type was {}".format(value))
                     self.options[name] = value
+                elif name == "patch_torch_functions":
+                    if self.opt_level != "O1" and value:
+                        warn_or_err("Currently, patch_torch_functions=True should only be set by "
+                                    "selecting opt_level='O1'.")
+                    self.options[name] = value
-                else:
-                    self.options[name] = float(value)
                 elif name == "keep_batchnorm_fp32":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                                    "to run in FP32, so keep_batchnorm_fp32 should be None. " +
+                                    "keep_batchnorm_fp32 was {}".format(value))
                     if value == "False":
                         self.options[name] = False
                     elif value == "True":
@@ -64,8 +77,18 @@ class Properties(object):
                     else:
                         assert (value is True or value is False or value is None),\
                             "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
-                            "or None"
+                            "or None, found keep_batchnorm_fp32={}".format(value)
                         self.options[name] = value
+                elif name == "master_weights":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("It doesn't make sense to use master_weights with O1. "
+                                    "With O1, your model weights themselves should be FP32.")
+                    self.options[name] = value
+                elif name == "loss_scale":
+                    if value == "dynamic":
+                        self.options[name] = value
+                    else:
+                        self.options[name] = float(value)
                 else:
                     self.options[name] = value
         else:
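The net effect of these __setattr__ checks is that incompatible overrides fail at amp.initialize time instead of silently misconfiguring the run. A hedged usage sketch of the user-facing behavior (assumes apex is installed and a CUDA device is available; the exact message text comes from the warn_or_err calls above):

import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# O1 patches Torch functions and keeps model weights in FP32, so asking it
# to cast the model conflicts with the opt_level; warn_or_err raises
# RuntimeError here (hard_override is no longer accepted as an escape hatch).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  cast_model_type=torch.float16)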
@@ -131,10 +154,10 @@ class O1:
     def __call__(self, properties):
         properties.enabled = True
         properties.opt_level = "O1"
-        properties.cast_model_type = False
+        properties.cast_model_type = None
         properties.patch_torch_functions = True
         properties.keep_batchnorm_fp32 = None
-        properties.master_weights = False
+        properties.master_weights = None
        properties.loss_scale = "dynamic"
         # properties.fused_optimizer = False
         # properties.enable_ddp_interop = False
@@ -167,7 +190,17 @@ opt_levels = {"O3": O3(),
 # allow user to directly pass Properties struct as well?
-def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
+def initialize(
+    models,
+    optimizers,
+    enabled=True,
+    opt_level=None,
+    cast_model_type=None,
+    patch_torch_functions=None,
+    keep_batchnorm_fp32=None,
+    master_weights=None,
+    loss_scale=None
+    ):
     """
     Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
     chosen ``opt_level`` and overridden properties, if any.
@@ -182,7 +215,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
             should run as if Amp were not present.
         opt_level(str, required): Pure or mixed precision optimization level. Accepted values are
             "O0", "O1", "O2", and "O3", which are explained in detail above.
-        cast_model_type (torch.dtype, optional, default=None): Optional property override, see
+        cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
             above.
         patch_torch_functions (bool, optional, default=None): Optional property override.
         keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
@@ -225,8 +258,6 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
         https://github.com/NVIDIA/apex/tree/master/examples/imagenet
     """
     if not enabled:
-        if "hard_override" in kwargs:
-            _amp_state.hard_override = kwargs["hard_override"]
         _amp_state.opt_properties = Properties()
         return models, optimizers
@@ -243,11 +274,22 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
         print("{:22} : {}".format(k, v))
     print("Processing user overrides (additional kwargs that are not None)...")
-    for k, v in kwargs.items():
-        if k not in _amp_state.opt_properties.options:
-            raise RuntimeError("Unexpected kwarg {}".format(k))
-        if v is not None:
-            setattr(_amp_state.opt_properties, k, v)
+    # I chose to have the keyword arguments listed directly in the argument list, so I
+    # can't use kwargs.items() here.
+    if enabled is not None:
+        _amp_state.opt_properties.enabled = enabled
+    if opt_level is not None:
+        _amp_state.opt_properties.opt_level = opt_level
+    if cast_model_type is not None:
+        _amp_state.opt_properties.cast_model_type = cast_model_type
+    if patch_torch_functions is not None:
+        _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if keep_batchnorm_fp32 is not None:
+        _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
+    if master_weights is not None:
+        _amp_state.opt_properties.master_weights = master_weights
+    if loss_scale is not None:
+        _amp_state.opt_properties.loss_scale = loss_scale
     print("After processing overrides, optimization options are:")
     for k, v in _amp_state.opt_properties.options.items():
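With the overrides spelled out as explicit keyword arguments, a typical call now looks like the sketch below (illustrative only; any argument left at None keeps the value selected by the opt_level, and loss_scale accepts either a float or the string "dynamic", per the Properties handling above):

import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# O2 defaults retained, except the loss scale is pinned to a static value.
# Passing loss_scale="dynamic" instead keeps dynamic loss scaling.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  loss_scale=128.0)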
@@ -58,6 +58,11 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
       set -x
       ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
@@ -90,6 +95,11 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR"
       set -x
       ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR
@@ -107,6 +117,11 @@ do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
       echo ""
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        echo "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR"
       set -x
       python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}