# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
from absl.testing import flagsaver
import tensorflow as tf  # pylint: disable=g-bad-import-order

FLAGS = flags.FLAGS


class KerasBenchmark(tf.test.Benchmark):
  """Base benchmark class with methods to simplify testing."""
  local_flags = None

  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
    # Fall back to /tmp/ when no output directory is provided.
    self.output_dir = output_dir or '/tmp/'
    self.default_flags = default_flags or {}
    self.flag_methods = flag_methods or []

  def _get_model_dir(self, folder_name):
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
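    # Flags are parsed only once per process: the first call defines and
    # parses them, applies the class defaults, and caches the result in
    # KerasBenchmark.local_flags; later calls restore that cached state so
    # every test starts from the same flag values.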
    if KerasBenchmark.local_flags is None:
      for flag_method in self.flag_methods:
        flag_method()
      # Parses flags once to load their defaults before overriding them below.
      # The argv list passed to FLAGS() must be non-empty, hence 'foo'.
      flags.FLAGS(['foo'])
      # Overrides flag values with defaults for the class of tests.
      for k, v in self.default_flags.items():
        setattr(FLAGS, k, v)
      saved_flag_values = flagsaver.save_flag_values()
      KerasBenchmark.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(KerasBenchmark.local_flags)

  def _report_benchmark(self,
                        stats,
                        wall_time_sec,
                        top_1_max=None,
                        top_1_min=None,
                        log_steps=None,
                        total_batch_size=None,
                        warmup=1):
    """Report benchmark results by writing to local protobuf file.

    Args:
      stats: dict returned from the Keras model run with known entries
        (e.g. 'accuracy_top_1', 'step_timestamp_log').
      wall_time_sec: the duration of the benchmark execution in seconds.
      top_1_max: highest passing level for top_1 accuracy.
      top_1_min: lowest passing level for top_1 accuracy.
      log_steps: number of steps between entries in
        stats['step_timestamp_log'].
      total_batch_size: global batch size for the run.
      warmup: number of entries in stats['step_timestamp_log'] to ignore.
    """

    metrics = []
    if 'accuracy_top_1' in stats:
      metrics.append({'name': 'accuracy_top_1',
                      'value': stats['accuracy_top_1'],
                      'min_value': top_1_min,
                      'max_value': top_1_max})
      metrics.append({'name': 'top_1_train_accuracy',
                      'value': stats['training_accuracy_top_1']})

    if (warmup and 'step_timestamp_log' in stats and
        len(stats['step_timestamp_log']) > warmup):
      # The first entry in the time_log is the start of step 1; the remaining
      # entries mark the end of each recorded step.
      time_log = stats['step_timestamp_log']
      elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
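      # Each interval between consecutive timestamps covers log_steps steps of
      # total_batch_size examples; after dropping the warmup entries there are
      # (len(time_log) - warmup - 1) such intervals within `elapsed`.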
      num_examples = (
          total_batch_size * log_steps * (len(time_log) - warmup - 1))
      examples_per_sec = num_examples / elapsed
      metrics.append({'name': 'exp_per_second',
                      'value': examples_per_sec})

    if 'avg_exp_per_second' in stats:
      metrics.append({'name': 'avg_exp_per_second',
                      'value': stats['avg_exp_per_second']})

    self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
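
# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module): a hypothetical subclass
# showing how the helpers above are typically wired together. The flag-defining
# method, the flag names, and the `run_and_time` runner below are assumptions
# for illustration only, not functions provided by this file.
# -----------------------------------------------------------------------------
#
# class ExampleKerasBenchmark(KerasBenchmark):
#
#   def __init__(self, output_dir=None):
#     super(ExampleKerasBenchmark, self).__init__(
#         output_dir=output_dir,
#         default_flags={'batch_size': 128, 'train_epochs': 1},
#         flag_methods=[define_example_flags])  # hypothetical flag definer
#
#   def benchmark_1_gpu(self):
#     self._setup()  # loads or restores the cached flag state
#     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
#     stats, wall_time_sec = run_and_time(FLAGS)  # hypothetical runner
#     self._report_benchmark(
#         stats,
#         wall_time_sec,
#         total_batch_size=FLAGS.batch_size,
#         log_steps=FLAGS.log_steps)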