Commit 8c1bccbc authored by A. Unique TensorFlower
Browse files

Measure the first batch's end time so that the start time can be estimated as start_time = first_batch_end - avg_step_time.

PiperOrigin-RevId: 331795709
parent cf63f0be
...@@ -30,9 +30,13 @@ from tensorflow.python.eager import monitoring ...@@ -30,9 +30,13 @@ from tensorflow.python.eager import monitoring
global_batch_size_gauge = monitoring.IntGauge( global_batch_size_gauge = monitoring.IntGauge(
'/tensorflow/training/global_batch_size', 'TF training global batch size') '/tensorflow/training/global_batch_size', 'TF training global batch size')
first_batch_start_time = monitoring.IntGauge( first_batch_time_gauge = monitoring.IntGauge(
'/tensorflow/training/first_batch_start', '/tensorflow/training/first_batch',
'TF training start time (unix epoch time in us.') 'TF training start/end time for first batch (unix epoch time in us.',
'type')
first_batch_start_time = first_batch_time_gauge.get_cell('start')
first_batch_end_time = first_batch_time_gauge.get_cell('end')
class BatchTimestamp(object): class BatchTimestamp(object):
...@@ -121,8 +125,8 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -121,8 +125,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
if not self.start_time: if not self.start_time:
self.start_time = time.time() self.start_time = time.time()
if not first_batch_start_time.get_cell().value(): if not first_batch_start_time.value():
first_batch_start_time.get_cell().set(int(self.start_time * 1000000)) first_batch_start_time.set(int(self.start_time * 1000000))
# Record the timestamp of the first global step # Record the timestamp of the first global step
if not self.timestamp_log: if not self.timestamp_log:
...@@ -131,6 +135,8 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -131,6 +135,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
def on_batch_end(self, batch, logs=None): def on_batch_end(self, batch, logs=None):
"""Records elapse time of the batch and calculates examples per second.""" """Records elapse time of the batch and calculates examples per second."""
if not first_batch_end_time.value():
first_batch_end_time.set(int(time.time() * 1000000))
self.steps_in_epoch = batch + 1 self.steps_in_epoch = batch + 1
steps_since_last_log = self.global_steps - self.last_log_step steps_since_last_log = self.global_steps - self.last_log_step
if steps_since_last_log >= self.log_steps: if steps_since_last_log >= self.log_steps:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment