Commit 808b2f5e authored by Sai Ganesh Bandiatmakuri, committed by A. Unique TensorFlower

Inject enable_runtime_flags into benchmarks and add train_steps to shakespeare_main.

This helps general debugging by enabling custom execution via
--benchmark_method_flags.

E.g., --benchmark_method_flags=train_steps=7 will run the benchmark for only 7
steps without modifying the benchmark code.
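
A minimal sketch of what such a runtime-override wrapper could look like, assuming absl flags. The helper name _parse_overrides and the key=value list format below are illustrative assumptions, not the exact contents of official/utils/testing/benchmark_wrappers.py:

import functools

from absl import flags

FLAGS = flags.FLAGS


def _parse_overrides(raw_overrides):
  """Turns e.g. ['train_steps=7'] into {'train_steps': '7'} (sketch only)."""
  parsed = {}
  for item in raw_overrides or []:
    key, _, value = item.partition('=')
    parsed[key] = value
  return parsed


def enable_runtime_flags(benchmark_method):
  """Applies --benchmark_method_flags overrides before the benchmark runs."""

  @functools.wraps(benchmark_method)
  def wrapper(self, *args, **kwargs):
    overrides = _parse_overrides(getattr(FLAGS, 'benchmark_method_flags', None))
    for name, value in overrides.items():
      # absl parses the string into the flag's declared type (int here).
      FLAGS[name].parse(value)
    return benchmark_method(self, *args, **kwargs)

  return wrapper

Decorating _run_and_report_benchmark this way means any flag defined by shakespeare_main.define_flags (such as the new train_steps) can be overridden at benchmark launch time.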

PiperOrigin-RevId: 282799341
parent 09a639bf
@@ -26,6 +26,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.staging.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
@@ -42,6 +43,7 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
        default_flags=default_flags,
        flag_methods=[shakespeare_main.define_flags])

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                top_1_train_min=0.91,
                                top_1_train_max=0.94,
@@ -75,6 +77,7 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
    for callback in stats['callbacks']:
      if isinstance(callback, keras_utils.TimeHistory):
        epoch_timings = callback.epoch_runtime_log
        if len(epoch_timings) > 1:
          average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
          metrics.append({'name': 'avg_epoch_time',
                          'value': average_time})
...
@@ -75,6 +75,8 @@ def define_flags():
  flags.DEFINE_integer(
      name='predict_length', default=1000,
      help='Length of the predicted text including the context.')
  flags.DEFINE_integer(name='train_steps', default=None,
                       help='Overrides train_steps per epoch if not None.')
  flags.DEFINE_integer(
      name='log_steps', default=100,
      help='For every log_steps, we log the timing information such as '
@@ -174,6 +176,9 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
  Returns:
    The training history and callbacks.
  """
  if flags_obj.train_steps:
    train_steps = flags_obj.train_steps
  else:
    train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
  strategy_scope = distribution_utils.get_strategy_scope(strategy)
...
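
As a rough illustration of the step-count selection the last hunk introduces (the BATCHES_PER_EPOCH and batch size values below are placeholders, not shakespeare_main's real constants):

# Illustration only: placeholder numbers, not the module's real constants.
BATCHES_PER_EPOCH = 11043
batch_size = 64


def steps_per_epoch(train_steps_flag):
  # --train_steps, when set, overrides the per-epoch step count; otherwise it
  # is derived from the dataset size and batch size as before.
  if train_steps_flag:
    return train_steps_flag
  return BATCHES_PER_EPOCH // batch_size


print(steps_per_epoch(None))  # default: 11043 // 64 == 172 steps per epoch
print(steps_per_epoch(7))     # quick debug run: 7 steps per epoch

Combined with the enable_runtime_flags decorator, --benchmark_method_flags=train_steps=7 reaches this code path without editing either the benchmark or shakespeare_main.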