"scripts/git@developer.sourcefind.cn:change/sglang.git" did not exist on "b5044fbf12d1764444ed8105c4520b3a24db0cea"
Unverified Commit f4b02d15 authored by Toby Boyd, committed by GitHub

Move Keras Hook to use global step to resolve issues across epochs. (#7186)

* Move to global_step.

* Hook to use global_step.

* Fix comment: logging starts at step 1, not step 0.

* Remove hack used for testing.

* Add docstring.
parent 9d53a513
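
Background on the change: Keras passes a per-epoch batch index to on_batch_begin/on_batch_end, and that index resets to 0 at every epoch boundary, so timing logic keyed on `batch` double-fires and misreports throughput once training runs past the first epoch. The commit instead has the callback own a monotonically increasing `global_steps` counter. Below is a minimal, self-contained sketch of that pattern, assuming tf.keras; the GlobalStepTimer name, toy model, and data are illustrative stand-ins, not the benchmark harness from this commit.

import time

import numpy as np
import tensorflow as tf


class GlobalStepTimer(tf.keras.callbacks.Callback):
  """Logs examples/second keyed on a global step, not the per-epoch batch."""

  def __init__(self, batch_size, log_steps):
    super(GlobalStepTimer, self).__init__()
    self.batch_size = batch_size
    self.log_steps = log_steps
    self.global_steps = 0  # Survives epoch boundaries, unlike `batch`.

  def on_batch_begin(self, batch, logs=None):
    self.global_steps += 1
    if self.global_steps == 1:
      self.start_time = time.time()  # Timing starts at step 1, not step 0.

  def on_batch_end(self, batch, logs=None):
    if self.global_steps % self.log_steps == 0:
      now = time.time()
      examples_per_second = (
          self.batch_size * self.log_steps) / (now - self.start_time)
      print('global step %d: %.1f examples/second' %
            (self.global_steps, examples_per_second))
      self.start_time = now  # Restart the window for the next log_steps.


model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(256, 4).astype('float32')
y = np.random.rand(256, 1).astype('float32')
# Two epochs of 8 batches each: `batch` restarts at 0 in epoch 2, while
# global_steps runs 1..16 and the log fires at steps 4, 8, 12, and 16.
model.fit(x, y, batch_size=32, epochs=2, verbose=0,
          callbacks=[GlobalStepTimer(batch_size=32, log_steps=4)])
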
@@ -22,10 +22,8 @@ import time
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
-from official.resnet import cifar10_main as cifar_main
 from official.resnet.keras import keras_benchmark
 from official.resnet.keras import keras_cifar_main
-from official.resnet.keras import keras_common
 
 MIN_TOP_1_ACCURACY = 0.925
 MAX_TOP_1_ACCURACY = 0.938
@@ -46,41 +46,37 @@ class TimeHistory(tf.keras.callbacks.Callback):
     Args:
       batch_size: Total batch size.
       log_steps: Interval of time history logs.
     """
     self.batch_size = batch_size
     super(TimeHistory, self).__init__()
     self.log_steps = log_steps
+    self.global_steps = 0
 
-    # Logs start of step 0 then end of each step based on log_steps interval.
+    # Logs start of step 1 then end of each step based on log_steps interval.
     self.timestamp_log = []
 
-  def on_train_begin(self, logs=None):
-    self.record_batch = True
-
   def on_train_end(self, logs=None):
     self.train_finish_time = time.time()
 
   def on_batch_begin(self, batch, logs=None):
-    if self.record_batch:
-      timestamp = time.time()
-      self.start_time = timestamp
-      self.record_batch = False
-      if batch == 0:
-        self.timestamp_log.append(BatchTimestamp(batch, timestamp))
+    self.global_steps += 1
+    if self.global_steps == 1:
+      self.start_time = time.time()
+      self.timestamp_log.append(BatchTimestamp(self.global_steps,
+                                               self.start_time))
 
   def on_batch_end(self, batch, logs=None):
-    if batch % self.log_steps == 0:
+    """Records elapse time of the batch and calculates examples per second."""
+    if self.global_steps % self.log_steps == 0:
       timestamp = time.time()
       elapsed_time = timestamp - self.start_time
       examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
-      if batch != 0:
-        self.record_batch = True
-        self.timestamp_log.append(BatchTimestamp(batch, timestamp))
-        tf.compat.v1.logging.info(
-            "BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
-            "'examples_per_second': %f}" %
-            (batch, elapsed_time, examples_per_second))
+      self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
+      tf.compat.v1.logging.info(
+          "BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
+          "'examples_per_second': %f}" %
+          (self.global_steps, elapsed_time, examples_per_second))
+      self.start_time = timestamp
 
 
 def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
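
Besides logging, the updated callback records a BatchTimestamp for step 1 and then one per log_steps interval, so benchmark code can derive average throughput from any two entries in timestamp_log. The sketch below shows that arithmetic; the namedtuple here is a stand-in with the (batch_index, timestamp) fields implied by how the diff above constructs BatchTimestamp, and average_examples_per_second is an illustrative helper, not part of the commit.

import collections

BatchTimestamp = collections.namedtuple('BatchTimestamp',
                                        ['batch_index', 'timestamp'])


def average_examples_per_second(timestamp_log, batch_size):
  """Mean throughput between the first and last logged global steps."""
  first, last = timestamp_log[0], timestamp_log[-1]
  steps = last.batch_index - first.batch_index
  return batch_size * steps / (last.timestamp - first.timestamp)


# Step 1 started at t=100.0 s and step 100 ended at t=125.0 s, so 99 steps
# of batch_size 128 took 25 s: 128 * 99 / 25 = 506.88 examples/second.
log = [BatchTimestamp(1, 100.0), BatchTimestamp(100, 125.0)]
print(average_examples_per_second(log, batch_size=128))
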