Unverified Commit f4b02d15 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Move Keras Hook to use global step to resolve issues across epochs. (#7186)

* Move to global_step.

* Hook to use global_step.

* fix comment start step 1 not step 0.

* remove hack used for testing.

* Add docstring.
parent 9d53a513
......@@ -22,10 +22,8 @@ import time
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_benchmark
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
MIN_TOP_1_ACCURACY = 0.925
MAX_TOP_1_ACCURACY = 0.938
......
......@@ -46,41 +46,37 @@ class TimeHistory(tf.keras.callbacks.Callback):
Args:
batch_size: Total batch size.
log_steps: Interval of time history logs.
"""
self.batch_size = batch_size
super(TimeHistory, self).__init__()
self.log_steps = log_steps
self.global_steps = 0
# Logs start of step 0 then end of each step based on log_steps interval.
# Logs start of step 1 then end of each step based on log_steps interval.
self.timestamp_log = []
def on_train_begin(self, logs=None):
self.record_batch = True
def on_train_end(self, logs=None):
self.train_finish_time = time.time()
def on_batch_begin(self, batch, logs=None):
if self.record_batch:
timestamp = time.time()
self.start_time = timestamp
self.record_batch = False
if batch == 0:
self.timestamp_log.append(BatchTimestamp(batch, timestamp))
self.global_steps += 1
if self.global_steps == 1:
self.start_time = time.time()
self.timestamp_log.append(BatchTimestamp(self.global_steps,
self.start_time))
def on_batch_end(self, batch, logs=None):
if batch % self.log_steps == 0:
"""Records elapse time of the batch and calculates examples per second."""
if self.global_steps % self.log_steps == 0:
timestamp = time.time()
elapsed_time = timestamp - self.start_time
examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
if batch != 0:
self.record_batch = True
self.timestamp_log.append(BatchTimestamp(batch, timestamp))
self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
tf.compat.v1.logging.info(
"BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
"'examples_per_second': %f}" %
(batch, elapsed_time, examples_per_second))
(self.global_steps, elapsed_time, examples_per_second))
self.start_time = timestamp
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment