Commit 808b2f5e authored by Sai Ganesh Bandiatmakuri's avatar Sai Ganesh Bandiatmakuri Committed by A. Unique TensorFlower
Browse files

Inject enable_runtime_flags into benchmarks and add train_steps to shakespeare_main.

This will help general debugging by enabling custom execution with
--benchmark_method_flags.

E.g., --benchmark_method_flags=train_steps=7 will run the benchmark for only 7
steps without modifying benchmark code.

PiperOrigin-RevId: 282799341
parent 09a639bf
......@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.staging.shakespeare import shakespeare_main
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
......@@ -42,6 +43,7 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
default_flags=default_flags,
flag_methods=[shakespeare_main.define_flags])
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_train_min=0.91,
top_1_train_max=0.94,
......@@ -75,9 +77,10 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
for callback in stats['callbacks']:
if isinstance(callback, keras_utils.TimeHistory):
epoch_timings = callback.epoch_runtime_log
average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
metrics.append({'name': 'avg_epoch_time',
'value': average_time})
if len(epoch_timings) > 1:
average_time = sum(epoch_timings[1:]) / len(epoch_timings[1:])
metrics.append({'name': 'avg_epoch_time',
'value': average_time})
# First entry in timestamp_log is the start of step 1. The rest of the
# entries are the end of each step recorded.
......
......@@ -75,6 +75,8 @@ def define_flags():
flags.DEFINE_integer(
name='predict_length', default=1000,
help='Length of the predicted text including the context.')
flags.DEFINE_integer(name='train_steps', default=None,
help='Overrides train_steps per epoch if not None.')
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
......@@ -174,7 +176,10 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
Returns:
The training history and callbacks.
"""
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
if flags_obj.train_steps:
train_steps = flags_obj.train_steps
else:
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
strategy_scope = distribution_utils.get_strategy_scope(strategy)
with strategy_scope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment