Unverified Commit 42a8af1d authored by Reed, committed by GitHub

Have each model provide a default loss scale. (#6930)

Previously, a single global default loss scale applied to all models. Currently only ResNet uses loss scaling, but a per-model default will be useful once more models support it.
parent 0a83bef9
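For context, loss scaling is the standard trick that keeps small fp16 gradients from underflowing to zero: multiply the loss by a constant before computing gradients, then divide the gradients by the same constant before applying them. A minimal sketch of the idea (illustrative only, not code from this commit; `model`, `optimizer`, `x`, and `y` are placeholders):

```python
import tensorflow as tf

LOSS_SCALE = 128  # the fp16 default that ResNet now passes explicitly


def train_step(model, optimizer, x, y):
  """One fp16 training step with a fixed loss scale."""
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(y, model(x)))
    scaled_loss = loss * LOSS_SCALE  # scale up so tiny fp16 gradients survive
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  grads = [g / LOSS_SCALE for g in scaled_grads]  # unscale before applying
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```

A scale that is too low still lets gradients underflow, while one that is too high overflows them, which is why a sensible default can differ from model to model.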
@@ -196,7 +196,8 @@ def run(flags_obj):
     # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
     # can be enabled with a single line of code.
     optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-        optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
+        optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
+                                                        default_for_fp16=128))

   if flags_obj.enable_xla and not flags_obj.enable_eager:
     # TODO(b/129861005): Fix OOM issue in eager mode when setting
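For readers unfamiliar with the API used above: `LossScaleOptimizer` wraps an existing optimizer so the scaling and unscaling happen automatically. A minimal sketch of the wrapping pattern (the SGD optimizer is a placeholder, and this assumes the TF 1.14-era experimental API that the diff uses):

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1)  # placeholder optimizer
# Wrap it so the loss is scaled before gradients are computed and the
# gradients are unscaled before they are applied.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    opt, loss_scale=128)  # a fixed scale; "dynamic" is also accepted
```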
@@ -591,7 +591,8 @@ def resnet_main(
           'data_format': flags_obj.data_format,
           'batch_size': flags_obj.batch_size,
           'resnet_version': int(flags_obj.resnet_version),
-          'loss_scale': flags_core.get_loss_scale(flags_obj),
+          'loss_scale': flags_core.get_loss_scale(flags_obj,
+                                                  default_for_fp16=128),
           'dtype': flags_core.get_tf_dtype(flags_obj),
           'fine_tune': flags_obj.fine_tune,
           'num_workers': num_workers,
@@ -26,10 +26,10 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.flags._conventions import help_wrap


-# Map string to (TensorFlow dtype, default loss scale)
+# Map string to TensorFlow dtype
 DTYPE_MAP = {
-    "fp16": (tf.float16, 128),
-    "fp32": (tf.float32, 1),
+    "fp16": tf.float16,
+    "fp32": tf.float32,
 }
@@ -38,15 +38,19 @@ def get_tf_dtype(flags_obj):
     # If the graph_rewrite is used, we build the graph with fp32, and let the
     # graph rewrite change ops to fp16.
     return tf.float32
-  return DTYPE_MAP[flags_obj.dtype][0]
+  return DTYPE_MAP[flags_obj.dtype]


-def get_loss_scale(flags_obj):
+def get_loss_scale(flags_obj, default_for_fp16):
   if flags_obj.loss_scale == "dynamic":
     return flags_obj.loss_scale
   elif flags_obj.loss_scale is not None:
     return float(flags_obj.loss_scale)
-  return DTYPE_MAP[flags_obj.dtype][1]
+  elif flags_obj.dtype == "fp32":
+    return 1  # No loss scaling is needed for fp32
+  else:
+    assert flags_obj.dtype == "fp16"
+    return default_for_fp16


 def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
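As a usage sketch of the new signature (hypothetical caller code, assuming flags have already been parsed; the `flags_core` import path is an assumption mirroring how the ResNet call sites above refer to it):

```python
from absl import flags

from official.utils.flags import core as flags_core

# Resolution order under the new signature:
#   --loss_scale=dynamic    -> "dynamic"
#   --loss_scale=<number>   -> float(<number>)
#   --dtype=fp32, no flag   -> 1 (fp32 needs no loss scaling)
#   --dtype=fp16, no flag   -> default_for_fp16, chosen per model
loss_scale = flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=128)
```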
@@ -80,20 +80,30 @@ class BaseTester(unittest.TestCase):
     assert flags.FLAGS.use_synthetic_data

   def test_parse_dtype_info(self):
-    for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
-                                            ["fp32", tf.float32, 1]]:
-      flags_core.parse_flags([__file__, "--dtype", dtype_str])
-      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)
-
-      flags_core.parse_flags(
-          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)
-
-      flags_core.parse_flags(
-          [__file__, "--dtype", dtype_str, "--loss_scale", "dynamic"])
-      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), "dynamic")
+    flags_core.parse_flags([__file__, "--dtype", "fp16"])
+    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 2)
+
+    flags_core.parse_flags(
+        [__file__, "--dtype", "fp16", "--loss_scale", "5"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 5)
+
+    flags_core.parse_flags(
+        [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), "dynamic")
+
+    flags_core.parse_flags([__file__, "--dtype", "fp32"])
+    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 1)
+
+    flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
+    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
+                                               default_for_fp16=2), 5)

     with self.assertRaises(SystemExit):
       flags_core.parse_flags([__file__, "--dtype", "int8"])
@@ -36,31 +36,30 @@
 3. **Flag values should not be mutated.**

    Instead of mutating flag values, use getter functions to return the desired values. An example
-   getter function is `get_loss_scale` function below:
+   getter function is `get_tf_dtype` function below:

    ```
-   # Map string to (TensorFlow dtype, default loss scale)
+   # Map string to TensorFlow dtype
    DTYPE_MAP = {
-       "fp16": (tf.float16, 128),
-       "fp32": (tf.float32, 1),
+       "fp16": tf.float16,
+       "fp32": tf.float32,
    }

-   def get_loss_scale(flags_obj):
-     if flags_obj.loss_scale == "dynamic":
-       return flags_obj.loss_scale
-     if flags_obj.loss_scale is not None:
-       return flags_obj.loss_scale
-     return DTYPE_MAP[flags_obj.dtype][1]
+   def get_tf_dtype(flags_obj):
+     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
+       # If the graph_rewrite is used, we build the graph with fp32, and let
+       # the graph rewrite change ops to fp16.
+       return tf.float32
+     return DTYPE_MAP[flags_obj.dtype]

    def main(_):
      flags_obj = flags.FLAGS()

      # Do not mutate flags_obj
-     # if flags_obj.loss_scale is None:
-     #   flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1]  # Don't do this
+     # if flags_obj.fp16_implementation == "graph_rewrite":
+     #   flags_obj.dtype = "float32"  # Don't do this

-     print(get_loss_scale(flags_obj))
+     print(get_tf_dtype(flags_obj))
      ...
    ```
\ No newline at end of file