Commit 8ca30bf6 authored by Will Cromar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 294711762
parent 658c84c8
@@ -42,7 +42,7 @@ flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
                      'instead of using Keras per-layer L2 loss.')
-def build_stats(train_result, eval_result, time_callback):
+def build_stats(train_result, eval_result, time_callback, avg_exp_per_second):
   """Normalizes and returns dictionary of stats.
   Args:
@@ -50,6 +50,7 @@ def build_stats(train_result, eval_result, time_callback):
     eval_result: Output of the eval step. Assumes first value is eval_loss and
       second value is accuracy_top_1.
     time_callback: Time tracking callback instance.
+    avg_exp_per_second: Average training examples per second.
   Returns:
     Dictionary of normalized results.
@@ -67,11 +68,8 @@ def build_stats(train_result, eval_result, time_callback):
     timestamp_log = time_callback.timestamp_log
     stats['step_timestamp_log'] = timestamp_log
     stats['train_finish_time'] = time_callback.train_finish_time
-    if len(timestamp_log) > 1:
-      stats['avg_exp_per_second'] = (
-          time_callback.batch_size * time_callback.log_steps *
-          (len(time_callback.timestamp_log) - 1) /
-          (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
+    stats['avg_exp_per_second'] = avg_exp_per_second
   return stats
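For reference, the deleted branch estimated throughput from the callback's timestamp log: the number of examples processed between the first and last logged steps, divided by the wall time between them. A standalone sketch of that calculation, assuming (as the removed code does) that each log entry exposes a `.timestamp` attribute:

def avg_exp_per_second_from_log(batch_size, log_steps, timestamp_log):
  # One log entry is recorded every `log_steps` batches, so the span between
  # the first and last entries covers (len - 1) * log_steps * batch_size examples.
  if len(timestamp_log) <= 1:
    return None  # mirrors the removed `if len(timestamp_log) > 1` guard
  examples = batch_size * log_steps * (len(timestamp_log) - 1)
  elapsed = timestamp_log[-1].timestamp - timestamp_log[0].timestamp
  return examples / elapsed

The new code drops this log-based estimate in favor of averaging per-epoch throughput measured directly in the training loop, as the hunks below show.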
@@ -348,6 +346,7 @@ def run(flags_obj):
   else:
     summary_writer = None
+  examples_per_second_history = []
   train_iter = iter(train_ds)
   time_callback.on_train_begin()
   for epoch in range(current_step // per_epoch_steps, train_epochs):
@@ -355,6 +354,7 @@
     training_accuracy.reset_states()
     steps_in_current_epoch = 0
+    time_callback.on_epoch_begin(epoch + 1)
     while steps_in_current_epoch < per_epoch_steps:
       time_callback.on_batch_begin(
           steps_in_current_epoch+epoch*per_epoch_steps)
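The new `on_epoch_begin` call pairs with the `on_epoch_end` call in the next hunk, and together they assume the time callback records each epoch's wall time in an `epoch_runtime_log` list. A minimal sketch of that assumed mechanic (an illustration only, not the repo's actual TimeHistory callback):

import time

class EpochTimer:
  # Hypothetical stand-in that records per-epoch wall time the way
  # the diff consumes it via time_callback.epoch_runtime_log[-1].
  def __init__(self):
    self.epoch_runtime_log = []
    self._epoch_start = None

  def on_epoch_begin(self, epoch):
    self._epoch_start = time.time()

  def on_epoch_end(self, epoch):
    self.epoch_runtime_log.append(time.time() - self._epoch_start)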
@@ -374,6 +374,12 @@
                  training_accuracy.result().numpy(),
                  epoch + 1)
+    time_callback.on_epoch_end(epoch + 1)
+    epoch_time = time_callback.epoch_runtime_log[-1]
+    steps_per_second = per_epoch_steps / epoch_time
+    examples_per_second = steps_per_second * flags_obj.batch_size
+    examples_per_second_history.append(examples_per_second)
     if (not flags_obj.skip_eval and
         (epoch + 1) % flags_obj.epochs_between_evals == 0):
       test_loss.reset_states()
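The throughput arithmetic added above is plain division and multiplication; with hypothetical numbers (all values below are illustrative, not from this commit):

per_epoch_steps = 1000  # steps in one epoch
epoch_time = 125.0      # seconds, as read from time_callback.epoch_runtime_log[-1]
batch_size = 256        # flags_obj.batch_size

steps_per_second = per_epoch_steps / epoch_time      # 8.0 steps/sec
examples_per_second = steps_per_second * batch_size  # 2048.0 examples/sec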
@@ -403,6 +409,8 @@
         tf.summary.scalar('eval_loss', test_loss.result(), current_steps)
         tf.summary.scalar(
             'eval_accuracy', test_accuracy.result(), current_steps)
+        tf.summary.scalar('global_step/sec', steps_per_second, current_steps)
+        tf.summary.scalar('examples/sec', examples_per_second, current_steps)
   time_callback.on_train_end()
   if summary_writer:
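The two new scalars are written through the eval summary writer already in scope. For context, the same TF2 summary pattern in isolation (the log directory, step, and values here are illustrative):

import tensorflow as tf

summary_writer = tf.summary.create_file_writer('/tmp/train_logs')  # illustrative path
with summary_writer.as_default():
  current_steps = 1000  # illustrative global step
  tf.summary.scalar('global_step/sec', 8.0, current_steps)
  tf.summary.scalar('examples/sec', 2048.0, current_steps)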
@@ -416,7 +424,12 @@
     train_result = [train_loss.result().numpy(),
                     training_accuracy.result().numpy()]
-  stats = build_stats(train_result, eval_result, time_callback)
+  stats = build_stats(
+      train_result,
+      eval_result,
+      time_callback,
+      tf.reduce_mean(examples_per_second_history).numpy(),
+  )
   return stats
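`tf.reduce_mean` accepts a plain Python list and returns a scalar tensor, so `.numpy()` recovers the float that build_stats stores. A quick check with made-up values:

import tensorflow as tf

examples_per_second_history = [2048.0, 2060.5, 2039.2]  # illustrative values
avg_exp_per_second = tf.reduce_mean(examples_per_second_history).numpy()
print(avg_exp_per_second)  # ~2049.23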