Commit bdae51af authored by Reed, committed by Toby Boyd

Include flags when reporting benchmark. (#6809)

This will allow one to easily reproduce a benchmark by re-running with the reported flags.
parent 23662bb4
@@ -26,6 +26,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 
 from official.resnet import cifar10_main as cifar_main
 from official.resnet import imagenet_main
+from official.utils.flags import core as flags_core
 from official.utils.logs import hooks
 
 IMAGENET_DATA_DIR_NAME = 'imagenet'
@@ -106,10 +107,12 @@ class EstimatorBenchmark(tf.test.Benchmark):
       exp_per_sec = sum(exp_per_second_list) / (len(exp_per_second_list))
       metrics.append({'name': 'exp_per_second',
                       'value': exp_per_sec})
 
+    flags_str = flags_core.get_nondefault_flags_as_str()
     self.report_benchmark(
         iters=eval_results['global_step'],
         wall_time=wall_time_sec,
-        metrics=metrics)
+        metrics=metrics,
+        extras={'flags': flags_str})
 
 class Resnet50EstimatorAccuracy(EstimatorBenchmark):
...
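For orientation, here is a minimal self-contained sketch of the reporting pattern this hunk adds. `tf.test.Benchmark.report_benchmark` accepts an `extras` dict of auxiliary values; the benchmark class, numbers, and flag string below are invented for illustration:

```python
import tensorflow as tf


class ToyBenchmark(tf.test.Benchmark):
  """Illustrative benchmark; the measured values are hard-coded."""

  def benchmark_toy(self):
    # In the real benchmarks, flags_core.get_nondefault_flags_as_str() builds
    # this string from whichever flags differ from their defaults.
    flags_str = '--batch_size=256 --use_synthetic_data'
    self.report_benchmark(
        iters=100,
        wall_time=12.3,
        metrics=[{'name': 'exp_per_second', 'value': 456.0}],
        extras={'flags': flags_str})
```

Whoever reads the resulting report can then copy `extras['flags']` straight onto the command line to reproduce the run.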
@@ -23,6 +23,7 @@ from __future__ import print_function
 
 import functools
 import sys
 
+from six.moves import shlex_quote
 from absl import app as absl_app
 from absl import flags
@@ -86,3 +87,45 @@ get_tf_dtype = _performance.get_tf_dtype
 get_loss_scale = _performance.get_loss_scale
 DTYPE_MAP = _performance.DTYPE_MAP
 require_cloud_storage = _device.require_cloud_storage
+
+
+def _get_nondefault_flags_as_dict():
+  """Returns the nondefault flags as a dict from flag name to value."""
+  nondefault_flags = {}
+  for flag_name in flags.FLAGS:
+    flag_value = getattr(flags.FLAGS, flag_name)
+    if (flag_name != flags.FLAGS[flag_name].short_name and
+        flag_value != flags.FLAGS[flag_name].default):
+      nondefault_flags[flag_name] = flag_value
+  return nondefault_flags
+
+
+def get_nondefault_flags_as_str():
+  """Returns flags as a string that can be passed as command line arguments.
+
+  E.g., returns: "--batch_size=256 --use_synthetic_data" for the following code
+  block:
+
+  ```
+  flags.FLAGS.batch_size = 256
+  flags.FLAGS.use_synthetic_data = True
+  print(get_nondefault_flags_as_str())
+  ```
+
+  Only flags with nondefault values are returned, as passing default flags as
+  command line arguments has no effect.
+
+  Returns:
+    A string with the flags, that can be passed as command line arguments to a
+    program to use the flags.
+  """
+  nondefault_flags = _get_nondefault_flags_as_dict()
+  flag_strings = []
+  for name, value in sorted(nondefault_flags.items()):
+    if isinstance(value, bool):
+      flag_str = '--{}'.format(name) if value else '--no{}'.format(name)
+    elif isinstance(value, list):
+      flag_str = '--{}={}'.format(name, ','.join(value))
+    else:
+      flag_str = '--{}={}'.format(name, value)
+    flag_strings.append(flag_str)
+  return ' '.join(shlex_quote(flag_str) for flag_str in flag_strings)
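As a usage sketch of the new helper (the two flags here are defined inline so the example is self-contained; the real models register theirs through the `flags_core.define_*` helpers):

```python
from absl import flags

from official.utils.flags import core as flags_core

flags.DEFINE_integer('batch_size', 32, 'Illustrative integer flag.')
flags.DEFINE_boolean('use_synthetic_data', False, 'Illustrative boolean flag.')
flags.FLAGS(['prog'])  # parse an empty command line

flags.FLAGS.batch_size = 256
flags.FLAGS.use_synthetic_data = True

# Prints "--batch_size=256 --use_synthetic_data"; passing that string back on
# the command line reproduces this configuration.
print(flags_core.get_nondefault_flags_as_str())
```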
@@ -102,6 +102,45 @@ class BaseTester(unittest.TestCase):
       flags_core.parse_flags([__file__, "--dtype", "fp16",
                               "--loss_scale", "abc"])
 
+  def test_get_nondefault_flags_as_str(self):
+    defaults = dict(
+        clean=True,
+        data_dir="abc",
+        hooks=["LoggingTensorHook"],
+        stop_threshold=1.5,
+        use_synthetic_data=False
+    )
+    flags_core.set_defaults(**defaults)
+    flags_core.parse_flags()
+
+    expected_flags = ""
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    flags.FLAGS.clean = False
+    expected_flags += "--noclean"
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    flags.FLAGS.data_dir = "xyz"
+    expected_flags += " --data_dir=xyz"
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    flags.FLAGS.hooks = ["aaa", "bbb", "ccc"]
+    expected_flags += " --hooks=aaa,bbb,ccc"
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    flags.FLAGS.stop_threshold = 3.
+    expected_flags += " --stop_threshold=3.0"
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    flags.FLAGS.use_synthetic_data = True
+    expected_flags += " --use_synthetic_data"
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
+
+    # Assert that explicitly setting a flag to its default value does not
+    # cause it to appear in the string.
+    flags.FLAGS.use_synthetic_data = False
+    expected_flags = expected_flags[:-len(" --use_synthetic_data")]
+    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)
 
 if __name__ == "__main__":
...
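One detail the test exercises: boolean flags are emitted without a value, using absl's `--flag` / `--noflag` spelling, so the generated string parses back to the same configuration. A quick sanity check (the flag name is illustrative):

```python
from absl import flags

flags.DEFINE_boolean('clean', True, 'Illustrative boolean flag.')

# absl parses "--noclean" as clean=False, matching the "--no{}" form that
# get_nondefault_flags_as_str emits for False-valued booleans.
flags.FLAGS(['prog', '--noclean'])
assert flags.FLAGS.clean is False
```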