# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

FLAGS = flags.FLAGS


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    """Records the wall-clock start time of the given batch."""
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    """Records the wall-clock stop time of the given batch."""
    # If there are multiple steps_per_loop, the end batch index will not be the
    # same as the starting index. Use the last starting index instead.
    if batch not in self.batch_start_times:
      batch = max(self.batch_start_times.keys())

    self.batch_stop_times[batch] = time.time()

  def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
    """Returns average examples per second, skipping the first few batches."""
    batch_durations = []
    for batch in self.batch_start_times:
      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
        batch_durations.append(self.batch_stop_times[batch] -
                               self.batch_start_times[batch])
    return batch_size / np.mean(batch_durations)

  def get_startup_time(self, program_start_time):
    """Returns seconds from program start until the first batch begins."""
    return self.batch_start_times[0] - program_start_time
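
# The snippet below is a minimal usage sketch, not part of this module: it
# assumes a compiled Keras `model` and a `train_dataset` (both hypothetical
# placeholders) and shows how the callback would typically be attached to a
# training run and queried afterwards.
#
#   timer = BenchmarkTimerCallback()
#   model.fit(train_dataset, epochs=1, steps_per_epoch=100, callbacks=[timer])
#   examples_per_sec = timer.get_examples_per_sec(batch_size=32)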


class BertBenchmarkBase(PerfZeroBenchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None):
    super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
    self.num_gpus = 8
    self.timer_callback = None

  def _setup(self):
    """Sets up and resets flags before each test."""
    super(BertBenchmarkBase, self)._setup()
    self.timer_callback = BenchmarkTimerCallback()

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
    """Report benchmark results by writing to local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds
      min_accuracy: Minimum classification accuracy constraint to verify
        correctness of the model.
      max_accuracy: Maximum classification accuracy constraint to verify
        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }]
    if self.timer_callback:
      metrics.append({
          'name':
              'exp_per_second',
          'value':
              self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
                                                       FLAGS.steps_per_loop)
      })
    else:
      metrics.append({
          'name': 'exp_per_second',
          'value': 0.0,
      })
    if self.timer_callback and 'start_time_sec' in stats:
      metrics.append({
          'name': 'startup_time',
          'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
      })

    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })
    flags_str = flags_core.get_nondefault_flags_as_str()
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics,
        extras={'flags': flags_str})
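
# A hedged sketch of how a concrete benchmark might use this base class. The
# subclass name, method name, accuracy bounds, and stats values below are
# illustrative assumptions, but the stats keys match what _report_benchmark
# reads above ('train_loss', 'total_training_steps', 'start_time_sec',
# 'eval_metrics').
#
#   class MyBertBenchmark(BertBenchmarkBase):
#
#     def benchmark_8_gpu(self):
#       self._setup()
#       start_time_sec = time.time()
#       # Run the BERT training loop here, passing self.timer_callback to it.
#       stats = {
#           'train_loss': 0.4,  # illustrative values only
#           'total_training_steps': 1000,
#           'start_time_sec': start_time_sec,
#           'eval_metrics': 0.86,
#       }
#       wall_time_sec = time.time() - start_time_sec
#       self._report_benchmark(
#           stats, wall_time_sec, min_accuracy=0.84, max_accuracy=0.88)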