Unverified commit 42a8af1d authored by Reed, committed by GitHub

Have each model provide a default loss scale. (#6930)

Before, there was a single global default loss scale shared by all models; now each model provides its own. Currently only ResNet uses loss scaling, but this will be useful once more models support it.
parent 0a83bef9
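For background on why an fp16 default matters at all, here is a toy, standalone illustration of the underflow problem that loss scaling works around. This snippet is not part of the commit; it only demonstrates float16 behavior with NumPy.

```python
import numpy as np

# Gradients in float16 underflow to zero once they drop below ~6e-8.
# Loss scaling multiplies the loss (and therefore the gradients) by a
# constant such as 128 before backprop, then divides the gradients by the
# same constant in float32 afterwards, so small values survive.
loss_scale = 128.0
tiny_grad = 1e-8

print(np.float16(tiny_grad))               # 0.0 -- lost in float16
print(np.float16(tiny_grad * loss_scale))  # ~1.3e-06 -- survives after scaling
```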
@@ -196,7 +196,8 @@ def run(flags_obj):
     # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
     # can be enabled with a single line of code.
     optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-        optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
+        optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
+                                                        default_for_fp16=128))
 
   if flags_obj.enable_xla and not flags_obj.enable_eager:
     # TODO(b/129861005): Fix OOM issue in eager mode when setting
@@ -591,7 +591,8 @@ def resnet_main(
           'data_format': flags_obj.data_format,
           'batch_size': flags_obj.batch_size,
           'resnet_version': int(flags_obj.resnet_version),
-          'loss_scale': flags_core.get_loss_scale(flags_obj),
+          'loss_scale': flags_core.get_loss_scale(flags_obj,
+                                                  default_for_fp16=128),
           'dtype': flags_core.get_tf_dtype(flags_obj),
           'fine_tune': flags_obj.fine_tune,
           'num_workers': num_workers,
@@ -26,10 +26,10 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.flags._conventions import help_wrap
 
-# Map string to (TensorFlow dtype, default loss scale)
+# Map string to TensorFlow dtype
 DTYPE_MAP = {
-    "fp16": (tf.float16, 128),
-    "fp32": (tf.float32, 1),
+    "fp16": tf.float16,
+    "fp32": tf.float32,
 }
@@ -38,15 +38,19 @@ def get_tf_dtype(flags_obj):
     # If the graph_rewrite is used, we build the graph with fp32, and let the
     # graph rewrite change ops to fp16.
     return tf.float32
-  return DTYPE_MAP[flags_obj.dtype][0]
+  return DTYPE_MAP[flags_obj.dtype]
 
 
-def get_loss_scale(flags_obj):
+def get_loss_scale(flags_obj, default_for_fp16):
   if flags_obj.loss_scale == "dynamic":
     return flags_obj.loss_scale
   elif flags_obj.loss_scale is not None:
     return float(flags_obj.loss_scale)
-  return DTYPE_MAP[flags_obj.dtype][1]
+  elif flags_obj.dtype == "fp32":
+    return 1  # No loss scaling is needed for fp32
+  else:
+    assert flags_obj.dtype == "fp16"
+    return default_for_fp16
 
 
 def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
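A minimal, self-contained sketch of how the new resolution order behaves. The `FakeFlags` tuple below is a hypothetical stand-in for the parsed absl flags object and is not part of the diff; the body mirrors the logic added above.

```python
import collections

# Hypothetical stand-in for the parsed flags object, for illustration only.
FakeFlags = collections.namedtuple("FakeFlags", ["dtype", "loss_scale"])


def get_loss_scale(flags_obj, default_for_fp16):
  # Mirrors the new _performance.py logic: an explicit --loss_scale flag
  # always wins, fp32 never needs scaling, and fp16 falls back to the
  # model-provided default.
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif flags_obj.dtype == "fp32":
    return 1
  else:
    assert flags_obj.dtype == "fp16"
    return default_for_fp16


print(get_loss_scale(FakeFlags("fp16", None), default_for_fp16=128))       # 128
print(get_loss_scale(FakeFlags("fp16", "dynamic"), default_for_fp16=128))  # dynamic
print(get_loss_scale(FakeFlags("fp32", None), default_for_fp16=128))       # 1
print(get_loss_scale(FakeFlags("fp32", "5"), default_for_fp16=128))        # 5.0
```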
@@ -80,20 +80,30 @@ class BaseTester(unittest.TestCase):
     assert flags.FLAGS.use_synthetic_data
 
   def test_parse_dtype_info(self):
-    for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
-                                            ["fp32", tf.float32, 1]]:
-      flags_core.parse_flags([__file__, "--dtype", dtype_str])
-
-      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)
-
-      flags_core.parse_flags(
-          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)
-
-      flags_core.parse_flags(
-          [__file__, "--dtype", dtype_str, "--loss_scale", "dynamic"])
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), "dynamic")
+    flags_core.parse_flags([__file__, "--dtype", "fp16"])
+    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 2)
+
+    flags_core.parse_flags(
+        [__file__, "--dtype", "fp16", "--loss_scale", "5"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 5)
+
+    flags_core.parse_flags(
+        [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), "dynamic")
+
+    flags_core.parse_flags([__file__, "--dtype", "fp32"])
+    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 1)
+
+    flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 5)
 
     with self.assertRaises(SystemExit):
       flags_core.parse_flags([__file__, "--dtype", "int8"])
@@ -36,31 +36,30 @@
 3. **Flag values should not be mutated.**
 
   Instead of mutating flag values, use getter functions to return the desired values. An example
-  getter function is `get_loss_scale` function below:
+  getter function is `get_tf_dtype` function below:
 
   ```
-  # Map string to (TensorFlow dtype, default loss scale)
+  # Map string to TensorFlow dtype
   DTYPE_MAP = {
-      "fp16": (tf.float16, 128),
-      "fp32": (tf.float32, 1),
+      "fp16": tf.float16,
+      "fp32": tf.float32,
   }
 
-  def get_loss_scale(flags_obj):
-    if flags_obj.loss_scale == "dynamic":
-      return flags_obj.loss_scale
-    if flags_obj.loss_scale is not None:
-      return flags_obj.loss_scale
-    return DTYPE_MAP[flags_obj.dtype][1]
+  def get_tf_dtype(flags_obj):
+    if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
+      # If the graph_rewrite is used, we build the graph with fp32, and let the
+      # graph rewrite change ops to fp16.
+      return tf.float32
+    return DTYPE_MAP[flags_obj.dtype]
 
   def main(_):
     flags_obj = flags.FLAGS()
 
     # Do not mutate flags_obj
-    # if flags_obj.loss_scale is None:
-    #   flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1]  # Don't do this
-    print(get_loss_scale(flags_obj))
+    # if flags_obj.fp16_implementation == "graph_rewrite":
+    #   flags_obj.dtype = "float32"  # Don't do this
+    print(get_tf_dtype(flags_obj))
     ...
   ```
\ No newline at end of file