Commit 68146271 authored by David Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285991432
parent 7a69f962
@@ -152,6 +152,8 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
     # Since we do not load from any pretrained checkpoints, we ignore all
     # accuracy metrics.
     summary.pop('eval_metrics', None)
+    summary['start_time_sec'] = start_time_sec
+
     super(BertClassifyBenchmarkReal, self)._report_benchmark(
         stats=summary,
         wall_time_sec=wall_time_sec,
...
@@ -38,24 +38,25 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
   def __init__(self, num_batches_to_skip=10):
     super(BenchmarkTimerCallback, self).__init__()
-    self.num_batches_to_skip = num_batches_to_skip
-    self.timer_records = []
-    self.start_time = None
+    self.batch_start_times = {}
+    self.batch_stop_times = {}

   def on_batch_begin(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    self.start_time = time.time()
+    self.batch_start_times[batch] = time.time()

   def on_batch_end(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    assert self.start_time
-    self.timer_records.append(time.time() - self.start_time)
+    self.batch_stop_times[batch] = time.time()

-  def get_examples_per_sec(self, batch_size):
-    return batch_size / np.mean(self.timer_records)
+  def get_examples_per_sec(self, batch_size, num_batches_to_skip=10):
+    batch_durations = []
+    for batch in self.batch_start_times:
+      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
+        batch_durations.append(self.batch_stop_times[batch] -
+                               self.batch_start_times[batch])
+    return batch_size / np.mean(batch_durations)
+
+  def get_startup_time(self, program_start_time):
+    return self.batch_start_times[0] - program_start_time

 class BertBenchmarkBase(tf.test.Benchmark):
@@ -113,6 +114,11 @@ class BertBenchmarkBase(tf.test.Benchmark):
         'name': 'exp_per_second',
         'value': 0.0,
     })
+    if self.timer_callback and 'start_time_sec' in stats:
+      metrics.append({
+          'name': 'startup_time',
+          'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
+      })
     if 'eval_metrics' in stats:
       metrics.append({
...
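For context on the new timer API: per-batch start and stop times are now recorded unconditionally, and the warm-up skip happens when the numbers are read back, which is also what lets the same records serve the new startup-time measurement. Below is a minimal usage sketch, assuming BenchmarkTimerCallback as defined above; the model, dataset, and batch size are illustrative stand-ins, not part of this change:

import time

import numpy as np
import tensorflow as tf

program_start_time = time.time()

# Illustrative model and data; any tf.keras model exercises the callback.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer='sgd', loss='mse')
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([640, 4]), tf.random.normal([640, 2]))).batch(32)

timer_callback = BenchmarkTimerCallback()
model.fit(dataset, epochs=1, callbacks=[timer_callback])

# 640 / 32 = 20 batches; skip the first 5 as warm-up when averaging.
print(timer_callback.get_examples_per_sec(batch_size=32, num_batches_to_skip=5))
# Seconds from program start until the first batch began.
print(timer_callback.get_startup_time(program_start_time))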
@@ -144,6 +144,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     wall_time_sec = time.time() - start_time_sec
     summary = self._read_training_summary_from_file()
+    summary['start_time_sec'] = start_time_sec

     super(BertSquadBenchmarkReal, self)._report_benchmark(
         stats=summary,
...
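Taken together, the changes thread timing through the reporting path: each benchmark subclass stamps summary['start_time_sec'] before calling _report_benchmark, and _report_benchmark converts that stamp into a startup_time metric via the callback. A rough sketch of the flow, with names taken from the diff and the surrounding driver code invented for illustration:

start_time_sec = time.time()
# ... training runs here; timer_callback records per-batch times ...
wall_time_sec = time.time() - start_time_sec

summary = {'total_training_steps': 1000}  # illustrative stats payload
summary['start_time_sec'] = start_time_sec

# Inside BertBenchmarkBase._report_benchmark (per the hunk above), roughly:
metrics = []
if timer_callback and 'start_time_sec' in summary:
  metrics.append({
      'name': 'startup_time',
      'value': timer_callback.get_startup_time(summary['start_time_sec']),
  })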