"vscode:/vscode.git/clone" did not exist on "01e6fd51f990bb8bbb9e8eb99cc38ae268a95c01"
Commit 8ca30bf6 authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 294711762
parent 658c84c8
@@ -42,7 +42,7 @@ flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
                      'instead of using Keras per-layer L2 loss.')

-def build_stats(train_result, eval_result, time_callback):
+def build_stats(train_result, eval_result, time_callback, avg_exp_per_second):
   """Normalizes and returns dictionary of stats.

   Args:
@@ -50,6 +50,7 @@ def build_stats(train_result, eval_result, time_callback):
     eval_result: Output of the eval step. Assumes first value is eval_loss and
       second value is accuracy_top_1.
     time_callback: Time tracking callback instance.
+    avg_exp_per_second: Average training examples per second.

   Returns:
     Dictionary of normalized results.
@@ -67,11 +68,8 @@ def build_stats(train_result, eval_result, time_callback):
     timestamp_log = time_callback.timestamp_log
     stats['step_timestamp_log'] = timestamp_log
     stats['train_finish_time'] = time_callback.train_finish_time
-    if len(timestamp_log) > 1:
-      stats['avg_exp_per_second'] = (
-          time_callback.batch_size * time_callback.log_steps *
-          (len(time_callback.timestamp_log) - 1) /
-          (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
+    stats['avg_exp_per_second'] = avg_exp_per_second

   return stats
@@ -348,6 +346,7 @@ def run(flags_obj):
   else:
     summary_writer = None

+  examples_per_second_history = []
   train_iter = iter(train_ds)
   time_callback.on_train_begin()
   for epoch in range(current_step // per_epoch_steps, train_epochs):
@@ -355,6 +354,7 @@ def run(flags_obj):
     training_accuracy.reset_states()
     steps_in_current_epoch = 0

+    time_callback.on_epoch_begin(epoch + 1)
     while steps_in_current_epoch < per_epoch_steps:
       time_callback.on_batch_begin(
           steps_in_current_epoch+epoch*per_epoch_steps)
@@ -374,6 +374,12 @@ def run(flags_obj):
             training_accuracy.result().numpy(),
             epoch + 1)

+    time_callback.on_epoch_end(epoch + 1)
+    epoch_time = time_callback.epoch_runtime_log[-1]
+    steps_per_second = per_epoch_steps / epoch_time
+    examples_per_second = steps_per_second * flags_obj.batch_size
+    examples_per_second_history.append(examples_per_second)
+
     if (not flags_obj.skip_eval and
         (epoch + 1) % flags_obj.epochs_between_evals == 0):
       test_loss.reset_states()
@@ -403,6 +409,8 @@ def run(flags_obj):
           tf.summary.scalar('eval_loss', test_loss.result(), current_steps)
           tf.summary.scalar(
               'eval_accuracy', test_accuracy.result(), current_steps)
+          tf.summary.scalar('global_step/sec', steps_per_second, current_steps)
+          tf.summary.scalar('examples/sec', examples_per_second, current_steps)

   time_callback.on_train_end()
   if summary_writer:
@@ -416,7 +424,12 @@ def run(flags_obj):
     train_result = [train_loss.result().numpy(),
                     training_accuracy.result().numpy()]
-  stats = build_stats(train_result, eval_result, time_callback)
+  stats = build_stats(
+      train_result,
+      eval_result,
+      time_callback,
+      tf.reduce_mean(examples_per_second_history).numpy(),
+  )
   return stats
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment