Commit e353e4e5 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Remove dynamic_loss_scale argument to define_performance.

All models that support loss scaling also support dynamic loss scaling, so the argument serves no purpose. Previously, some models scaled the loss manually instead of using a LossScaleOptimizer and therefore did not support dynamic loss scaling.

PiperOrigin-RevId: 367719521
parent c4b23773
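For background, here is a minimal, illustrative sketch of dynamic loss scaling via tf.keras.mixed_precision.LossScaleOptimizer, the mechanism the message refers to. The model, optimizer, and loss below are placeholders, not code from this change:

import tensorflow as tf

# Illustrative model/optimizer; any Keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(0.01)

# Wrapping the optimizer enables dynamic loss scaling by default, so no
# separate dynamic_loss_scale toggle is needed.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)


@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
    # Multiply the loss by the current loss scale so small gradients do not
    # underflow to zero in float16.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Divide the gradients back down before applying them to the variables.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

Because the wrapper adjusts the scale automatically (lowering it on overflow, raising it periodically), any model that wraps its optimizer this way gets dynamic loss scaling for free, which is why a separate dynamic_loss_scale toggle is redundant.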
@@ -100,7 +100,6 @@ def define_common_bert_flags():
       synthetic_data=False,
       max_train_steps=False,
       dtype=True,
-      dynamic_loss_scale=True,
       loss_scale=True,
       all_reduce_alg=True,
       num_packs=False,
@@ -162,7 +162,6 @@ def define_ncf_flags():
       dtype=True,
       fp16_implementation=True,
       loss_scale=True,
-      dynamic_loss_scale=True,
       enable_xla=True,
   )
   flags_core.define_device(tpu=True)
@@ -61,7 +61,6 @@ def define_performance(num_parallel_calls=False,
                        tf_gpu_thread_mode=False,
                        datasets_num_private_threads=False,
                        datasets_num_parallel_batches=False,
-                       dynamic_loss_scale=False,
                        fp16_implementation=False,
                        loss_scale=False,
                        tf_data_experimental_slack=False,
@@ -84,8 +83,6 @@ def define_performance(num_parallel_calls=False,
     datasets_num_private_threads: Number of private threads for datasets.
     datasets_num_parallel_batches: Determines how many batches to process in
       parallel when using map and batch from tf.data.
-    dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
-      "dynamic". Only valid if `dtype` is True.
     fp16_implementation: Create fp16_implementation flag.
     loss_scale: Controls the loss scaling, normally for mixed-precision
       training. Can only be turned on if dtype is also True.
@@ -156,45 +153,37 @@ def define_performance(num_parallel_calls=False,
         default="fp32",
         enum_values=DTYPE_MAP.keys(),
         help=help_wrap("The TensorFlow datatype used for calculations. "
-                       "Variables may be cast to a higher precision on a "
-                       "case-by-case basis for numerical stability."))
+                       "For 16-bit dtypes, variables and certain ops will "
+                       "still be float32 for numeric stability."))

-    loss_scale_help_text = (
-        "The amount to scale the loss by when the model is run. {}. Before "
-        "gradients are computed, the loss is multiplied by the loss scale, "
-        "making all gradients loss_scale times larger. To adjust for this, "
-        "gradients are divided by the loss scale before being applied to "
-        "variables. This is mathematically equivalent to training without "
-        "a loss scale, but the loss scale helps avoid some intermediate "
-        "gradients from underflowing to zero. If not provided the default "
-        "for fp16 is 128 and 1 for all other dtypes.{}")
-    if dynamic_loss_scale:
-      loss_scale_help_text = loss_scale_help_text.format(
-          "This can be an int/float or the string 'dynamic'",
-          " The string 'dynamic' can be used to dynamically determine the "
-          "optimal loss scale during training, but currently this "
-          "significantly slows down performance")
-      loss_scale_validation_msg = ("loss_scale should be a positive int/float "
-                                   "or the string 'dynamic'.")
-    else:
-      loss_scale_help_text = loss_scale_help_text.format(
-          "This must be an int/float", "")
-      loss_scale_validation_msg = "loss_scale should be a positive int/float."
     if loss_scale:
       flags.DEFINE_string(
           name="loss_scale",
           short_name="ls",
           default=None,
-          help=help_wrap(loss_scale_help_text))
+          help=help_wrap(
+              "The amount to scale the loss by when --dtype=fp16. This can be "
+              "an int/float or the string 'dynamic'. Before gradients are "
+              "computed, the loss is multiplied by the loss scale, making all "
+              "gradients loss_scale times larger. To adjust for this, "
+              "gradients are divided by the loss scale before being applied to "
+              "variables. This is mathematically equivalent to training "
+              "without a loss scale, but the loss scale helps avoid some "
+              "intermediate gradients from underflowing to zero. The default "
+              "is 'dynamic', which dynamically determines the optimal loss scale "
+              "during training."))

+      # pylint: disable=unused-variable
       @flags.validator(
-          flag_name="loss_scale", message=loss_scale_validation_msg)
-      def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
+          flag_name="loss_scale",
+          message="loss_scale should be a positive int/float or the string "
+          "'dynamic'.")
+      def _check_loss_scale(loss_scale):
         """Validator to check the loss scale flag is valid."""
         if loss_scale is None:
           return True  # null case is handled in get_loss_scale()

-        if loss_scale == "dynamic" and dynamic_loss_scale:
+        if loss_scale == "dynamic":
           return True

         try:
@@ -203,6 +192,7 @@ def define_performance(num_parallel_calls=False,
           return False

         return loss_scale > 0
+      # pylint: enable=unused-variable

   if fp16_implementation:
     flags.DEFINE_enum(
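As a rough sketch of how the --loss_scale flag defined above is typically consumed (parse_loss_scale and wrap_optimizer are hypothetical helper names, not functions from this repository):

import tensorflow as tf


def parse_loss_scale(flag_value):
  """Turns the --loss_scale string into 'dynamic' or a positive float."""
  if flag_value is None or flag_value == "dynamic":
    # Mirrors the new default described in the help text above.
    return "dynamic"
  return float(flag_value)


def wrap_optimizer(optimizer, flag_value):
  """Wraps an optimizer with either dynamic or fixed loss scaling."""
  loss_scale = parse_loss_scale(flag_value)
  if loss_scale == "dynamic":
    return tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
  # Fixed loss scale: no automatic adjustment during training.
  return tf.keras.mixed_precision.LossScaleOptimizer(
      optimizer, dynamic=False, initial_scale=loss_scale)

With the new default of 'dynamic', leaving --loss_scale unset gives dynamic loss scaling whenever --dtype=fp16.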
@@ -32,7 +32,6 @@ def define_flags():
       num_parallel_calls=True,
       inter_op=True,
       intra_op=True,
-      dynamic_loss_scale=True,
       loss_scale=True,
       synthetic_data=True,
       dtype=True)
@@ -188,8 +188,7 @@ def build_stats(history, eval_output, callbacks):
   return stats


-def define_keras_flags(dynamic_loss_scale=True,
-                       model=False,
+def define_keras_flags(model=False,
                        optimizer=False,
                        pretrained_filepath=False):
   """Define flags for Keras models."""
@@ -208,7 +207,6 @@ def define_keras_flags(dynamic_loss_scale=True,
       num_packs=True,
       tf_gpu_thread_mode=True,
       datasets_num_private_threads=True,
-      dynamic_loss_scale=dynamic_loss_scale,
       loss_scale=True,
       fp16_implementation=True,
       tf_data_experimental_slack=True,