Unverified Commit f4b02d15 authored by Toby Boyd, committed by GitHub

Move Keras Hook to use global step to resolve issues across epochs. (#7186)

* Move to global_step.

* Hook to use global_step.

* Fix comment: logging starts at step 1, not step 0.

* Remove hack used for testing.

* Add docstring.
parent 9d53a513
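
Background for the fix: Keras passes an epoch-local batch index to `on_batch_begin`/`on_batch_end`, so the `batch` argument resets to 0 at every epoch boundary. A hook that keys its logging on `batch % log_steps` therefore re-triggers and mis-measures across epochs; keeping a counter that survives epoch boundaries (the `global_steps` introduced below) resolves this. A minimal, hypothetical probe illustrating the difference follows; the `StepProbe` name, toy model, and data are illustrative, not part of this commit:

    import numpy as np
    import tensorflow as tf

    class StepProbe(tf.keras.callbacks.Callback):
      """Contrasts the epoch-local `batch` argument with a global counter."""

      def __init__(self):
        super(StepProbe, self).__init__()
        self.global_steps = 0

      def on_batch_end(self, batch, logs=None):
        self.global_steps += 1
        # `batch` restarts at 0 each epoch; `global_steps` keeps counting.
        print('epoch-local batch=%d, global step=%d'
              % (batch, self.global_steps))

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')
    x = np.random.rand(8, 4).astype('float32')
    y = np.random.rand(8, 1).astype('float32')
    # Two epochs of two steps each: `batch` prints 0, 1, 0, 1 while
    # `global_steps` prints 1, 2, 3, 4.
    model.fit(x, y, batch_size=4, epochs=2, callbacks=[StepProbe()], verbose=0)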
...
@@ -22,10 +22,8 @@ import time
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
-from official.resnet import cifar10_main as cifar_main
 from official.resnet.keras import keras_benchmark
 from official.resnet.keras import keras_cifar_main
-from official.resnet.keras import keras_common
 
 MIN_TOP_1_ACCURACY = 0.925
 MAX_TOP_1_ACCURACY = 0.938
...
@@ -46,41 +46,37 @@ class TimeHistory(tf.keras.callbacks.Callback):
     Args:
       batch_size: Total batch size.
       log_steps: Interval of time history logs.
     """
     self.batch_size = batch_size
     super(TimeHistory, self).__init__()
     self.log_steps = log_steps
+    self.global_steps = 0
 
-    # Logs start of step 0 then end of each step based on log_steps interval.
+    # Logs start of step 1 then end of each step based on log_steps interval.
     self.timestamp_log = []
 
-  def on_train_begin(self, logs=None):
-    self.record_batch = True
-
   def on_train_end(self, logs=None):
     self.train_finish_time = time.time()
 
   def on_batch_begin(self, batch, logs=None):
-    if self.record_batch:
-      timestamp = time.time()
-      self.start_time = timestamp
-      self.record_batch = False
-      if batch == 0:
-        self.timestamp_log.append(BatchTimestamp(batch, timestamp))
+    self.global_steps += 1
+    if self.global_steps == 1:
+      self.start_time = time.time()
+      self.timestamp_log.append(BatchTimestamp(self.global_steps,
+                                               self.start_time))
 
   def on_batch_end(self, batch, logs=None):
-    if batch % self.log_steps == 0:
+    """Records elapse time of the batch and calculates examples per second."""
+    if self.global_steps % self.log_steps == 0:
       timestamp = time.time()
       elapsed_time = timestamp - self.start_time
       examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
-      if batch != 0:
-        self.record_batch = True
-        self.timestamp_log.append(BatchTimestamp(batch, timestamp))
-        tf.compat.v1.logging.info(
-            "BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
-            "'examples_per_second': %f}" %
-            (batch, elapsed_time, examples_per_second))
+      self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
+      tf.compat.v1.logging.info(
+          "BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
+          "'examples_per_second': %f}" %
+          (self.global_steps, elapsed_time, examples_per_second))
+      self.start_time = timestamp
 
 
 def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
...
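
For completeness, a sketch of how the updated hook is consumed, assuming `TimeHistory` is exposed from `official.resnet.keras.keras_common` (consistent with the import removed in the first hunk); the toy model and data are illustrative:

    import numpy as np
    import tensorflow as tf

    from official.resnet.keras import keras_common

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer='sgd', loss='mse')

    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')

    # batch_size passed to the hook should match model.fit's batch_size so
    # the examples_per_second calculation is meaningful.
    time_callback = keras_common.TimeHistory(batch_size=8, log_steps=4)

    # Two epochs of 8 steps each: with the global step counter, the
    # BenchmarkMetric log lines report steps 4, 8, 12, 16 instead of
    # restarting the step count at each epoch boundary.
    model.fit(x, y, batch_size=8, epochs=2,
              callbacks=[time_callback], verbose=0)

    print(len(time_callback.timestamp_log))  # BatchTimestamp entries recorded.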