Commit 3269c84b authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move utils to TimeHistory

PiperOrigin-RevId: 315287705
parent ea61bbf0
...@@ -44,21 +44,6 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi ...@@ -44,21 +44,6 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long # pylint: enable=line-too-long
class TimerCallback(keras_utils.TimeHistory):
"""TimeHistory subclass for benchmark reporting."""
def get_examples_per_sec(self, warmup=1):
# First entry in timestamp_log is the start of the step 1. The rest of the
# entries are the end of each step recorded.
time_log = self.timestamp_log
seconds = time_log[-1].timestamp - time_log[warmup].timestamp
steps = time_log[-1].batch_index - time_log[warmup].batch_index
return self.batch_size * steps / seconds
def get_startup_time(self, start_time_sec):
return self.timestamp_log[0].timestamp - start_time_sec
class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark): class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
"""Base class to hold methods common to test classes.""" """Base class to hold methods common to test classes."""
...@@ -151,7 +136,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -151,7 +136,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
"""Starts RetinaNet accuracy benchmark test.""" """Starts RetinaNet accuracy benchmark test."""
FLAGS.params_override = json.dumps(params) FLAGS.params_override = json.dumps(params)
# Need timer callback to measure performance # Need timer callback to measure performance
self.timer_callback = TimerCallback( self.timer_callback = keras_utils.TimeHistory(
batch_size=params['train']['batch_size'], batch_size=params['train']['batch_size'],
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
) )
......
...@@ -85,6 +85,18 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -85,6 +85,18 @@ class TimeHistory(tf.keras.callbacks.Callback):
"""The average number of training examples per second across all epochs.""" """The average number of training examples per second across all epochs."""
return self.average_steps_per_second * self.batch_size return self.average_steps_per_second * self.batch_size
def get_examples_per_sec(self, warmup=1):
"""Calculates examples/sec through timestamp_log and skip warmup period."""
# First entry in timestamp_log is the start of the step 1. The rest of the
# entries are the end of each step recorded.
time_log = self.timestamp_log
seconds = time_log[-1].timestamp - time_log[warmup].timestamp
steps = time_log[-1].batch_index - time_log[warmup].batch_index
return self.batch_size * steps / seconds
def get_startup_time(self, start_time_sec):
return self.timestamp_log[0].timestamp - start_time_sec
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
self.train_finish_time = time.time() self.train_finish_time = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment