Commit 248d7b10 authored by Michael Carilli

Updating error checking on property overrides

parent d44ce75a
@@ -13,8 +13,10 @@ def warn_or_err(msg):
     if _amp_state.hard_override:
         print("Warning: " + msg)
     else:
-        raise RuntimeError(msg + " If you're sure you know what you're doing, supply " +
-                           "hard_override=True to amp.initialize.")
+        raise RuntimeError(msg)
+        # I'm not sure if allowing hard_override is a good idea.
+        # + " If you're sure you know what you're doing, supply " +
+        # "hard_override=True to amp.initialize.")

 # def iter_params(param_groups):
 #     for group in param_groups:
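For reference, a standalone sketch of the behavior this hunk leaves in place (the toy AmpState class below is only for illustration; in apex the flag lives on the shared _amp_state object): a rejected property override now raises a plain RuntimeError whose message no longer advertises hard_override=True, and only an already-set hard_override downgrades it to a printed warning.

    # Illustration only -- mirrors the post-commit warn_or_err behavior.
    class AmpState:              # stand-in for the real shared _amp_state
        hard_override = False

    _amp_state = AmpState()

    def warn_or_err(msg):
        if _amp_state.hard_override:
            print("Warning: " + msg)
        else:
            # The hint suggesting hard_override=True has been dropped from the message.
            raise RuntimeError(msg)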
@@ -163,10 +163,12 @@ def _initialize(models, optimizers, properties):
                 optimizers[i] = wrap_fused_adam(optimizer, properties)
             if properties.loss_scale == "dynamic":
                 optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       dynamic_loss_scale=True)
+                                                       dynamic_loss_scale=True,
+                                                       verbose=False)
             else:
                 optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       static_loss_scale=properties.loss_scale)
+                                                       static_loss_scale=properties.loss_scale,
+                                                       verbose=False)
     else:
         for optimizer in optimizers:
             optimizer.loss_scaler = LossScaler(properties.loss_scale)
@@ -16,7 +16,7 @@ class Properties(object):
             "cast_model_type" : None,
             "patch_torch_functions" : False,
             "keep_batchnorm_fp32" : None,
-            "master_weights" : False,
+            "master_weights" : None,
             "loss_scale" : 1.0,
             # Reserved for future functionality
             # "fused_optimizer" : False,
@@ -51,12 +51,25 @@ class Properties(object):
         if "options" in self.__dict__:
             if name in self.options:
                 # print("setting {} {}".format(name, value))
-                if name == "loss_scale":
-                    if value == "dynamic":
-                        self.options[name] = value
-                    else:
-                        self.options[name] = float(value)
+                if name == "cast_model_type":
+                    if self.opt_level == "O1" and value is not None:
+                        if value is not torch.float32:
+                            warn_or_err("O1 inserts casts around Torch functions rather than "
+                                        "model weights, so with O1, the model weights themselves "
+                                        "should remain FP32. If you wish to cast the model to a "
+                                        "different type, use opt_level='O2' or 'O3'. " +
+                                        "cast_model_type was {}".format(value))
+                    self.options[name] = value
+                elif name == "patch_torch_functions":
+                    if self.opt_level != "O1" and value:
+                        warn_or_err("Currently, patch_torch_functions=True should only be set by "
+                                    "selecting opt_level='O1'.")
+                    self.options[name] = value
                 elif name == "keep_batchnorm_fp32":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                                    "to run in FP32, so keep_batchnorm_fp32 should be None." +
+                                    "keep_batchnorm_fp32 was {}".format(keep_batchnorm_fp32))
                     if value == "False":
                         self.options[name] = False
                     elif value == "True":
@@ -64,8 +77,18 @@ class Properties(object):
                     else:
                         assert (value is True or value is False or value is None),\
                             "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
-                            "or None"
+                            "or None, found keep_batchnorm_fp32={}".format(keep_batchnorm_fp32)
                         self.options[name] = value
+                elif name == "master_weights":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("It doesn't make sense to use master_weights with O1. "
+                                    "With O1, your model weights themselves should be FP32.")
+                    self.options[name] = value
+                elif name == "loss_scale":
+                    if value == "dynamic":
+                        self.options[name] = value
+                    else:
+                        self.options[name] = float(value)
                 else:
                     self.options[name] = value
             else:
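Taken together, the new branches route O1-incompatible overrides (cast_model_type, keep_batchnorm_fp32, master_weights) and patch_torch_functions=True outside O1 through warn_or_err, so they either warn or raise depending on hard_override. A hedged usage sketch, assuming an apex build that includes this commit and a CUDA device (the tiny model and optimizer here are only for illustration):

    import torch
    from apex import amp

    model = torch.nn.Linear(4, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # O1 patches Torch functions rather than casting weights, so asking for a
    # half-precision model cast is now flagged; with hard_override unset this
    # is expected to raise RuntimeError via warn_or_err.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                      cast_model_type=torch.float16)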
@@ -131,10 +154,10 @@ class O1:
     def __call__(self, properties):
         properties.enabled = True
         properties.opt_level = "O1"
-        properties.cast_model_type = False
+        properties.cast_model_type = None
         properties.patch_torch_functions = True
         properties.keep_batchnorm_fp32 = None
-        properties.master_weights = False
+        properties.master_weights = None
         properties.loss_scale = "dynamic"
         # properties.fused_optimizer = False
         # properties.enable_ddp_interop = False
@@ -167,7 +190,17 @@ opt_levels = {"O3": O3(),
 # allow user to directly pass Properties struct as well?
-def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
+def initialize(
+    models,
+    optimizers,
+    enabled=True,
+    opt_level=None,
+    cast_model_type=None,
+    patch_torch_functions=None,
+    keep_batchnorm_fp32=None,
+    master_weights=None,
+    loss_scale=None
+    ):
     """
     Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
     chosen ``opt_level`` and overridden properties, if any.
@@ -182,7 +215,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
         should run as if Amp were not present.
     opt_level(str, required): Pure or mixed precision optimization level. Accepted values are
         "O0", "O1", "O2", and "O3", which are explained in detail above.
-    cast_model_type (torch.dtype, optional, default=None): Optional property override, see
+    cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
         above.
     patch_torch_functions (bool, optional, default=None): Optional property override.
     keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
@@ -225,8 +258,6 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
         https://github.com/NVIDIA/apex/tree/master/examples/imagenet
     """
     if not enabled:
-        if "hard_override" in kwargs:
-            _amp_state.hard_override = kwargs["hard_override"]
         _amp_state.opt_properties = Properties()
         return models, optimizers
@@ -243,11 +274,22 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
             print("{:22} : {}".format(k, v))
     print("Processing user overrides (additional kwargs that are not None)...")
-    for k, v in kwargs.items():
-        if k not in _amp_state.opt_properties.options:
-            raise RuntimeError("Unexpected kwarg {}".format(k))
-        if v is not None:
-            setattr(_amp_state.opt_properties, k, v)
+    # I chose to have the keyword arguments listed directly in the argument list, so I
+    # can't use kwargs.items() here.
+    if enabled is not None:
+        _amp_state.opt_properties.enabled = enabled
+    if opt_level is not None:
+        _amp_state.opt_properties.opt_level = opt_level
+    if cast_model_type is not None:
+        _amp_state.opt_properties.cast_model_type = cast_model_type
+    if patch_torch_functions is not None:
+        _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
+    if keep_batchnorm_fp32 is not None:
+        _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
+    if master_weights is not None:
+        _amp_state.opt_properties.master_weights = master_weights
+    if loss_scale is not None:
+        _amp_state.opt_properties.loss_scale = loss_scale
     print("After processing overrides, optimization options are:")
     for k, v in _amp_state.opt_properties.options.items():
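Because the overrides are now named parameters instead of **kwargs, a misspelled option fails immediately with a TypeError from Python itself rather than the old manual "Unexpected kwarg" check, and anything left at None simply keeps the value chosen by the opt_level. A legitimate override call looks the same from the caller's side; a sketch, again assuming an apex build containing this commit (the model and optimizer are placeholders):

    import torch
    from apex import amp

    model = torch.nn.Linear(4, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Only the non-None overrides are applied on top of the O2 defaults.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=True,
                                      loss_scale="dynamic")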
@@ -58,6 +58,11 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
       set -x
       ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
@@ -90,6 +95,11 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR"
       set -x
       ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR
@@ -107,6 +117,11 @@ do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
       echo ""
+      if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
+      then
+        echo "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
+        continue
+      fi
       echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR"
       set -x
       python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}