# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Executes Estimator benchmarks and accuracy tests."""
Shining Sun's avatar
Shining Sun committed
16
17
18

from __future__ import absolute_import
from __future__ import division
19
20
21
from __future__ import print_function

import os
Toby Boyd's avatar
Toby Boyd committed
22
import time
23
24
25
26
27
28

from absl import flags
from absl.testing import flagsaver
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import cifar10_main as cifar_main
29
from official.utils.logs import hooks
30

Toby Boyd's avatar
Toby Boyd committed
31
32
33
# Bounds attached to the `accuracy_top_1` metric as its `min_value` and
# `max_value` when a benchmark result is reported.
MIN_TOP_1_ACCURACY = 0.926
MAX_TOP_1_ACCURACY = 0.938

34

35
class EstimatorCifar10BenchmarkTests(tf.test.Benchmark):
  """Benchmarks and accuracy tests for Estimator ResNet56.

  Every public method configures the absl flags for one scenario, runs
  `cifar_main.run_cifar` with them and reports accuracy and throughput
  through `report_benchmark`.
  """

  # Snapshot of the flag values taken by the first `_setup()` call. Later calls
  # restore it so overrides made by one benchmark do not leak into the next.
  local_flags = None

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    """A benchmark class.

    Args:
      output_dir: directory where to output e.g. log files. When None, model
        directories are resolved relative to the current working directory.
      root_data_dir: directory under which to look for dataset. When None, the
        dataset folder is resolved relative to the current working directory.
      **kwargs: arbitrary named arguments. This is needed to make the
                constructor forward compatible in case PerfZero provides more
                named arguments before updating the constructor.
    """
    self.output_dir = output_dir
    # The argument defaults to None, which os.path.join() rejects with a
    # TypeError; use an empty prefix so a bare constructor call works.
    self.data_dir = os.path.join(root_data_dir or '', 'cifar-10-batches-bin')

  def resnet56_1_gpu(self):
    """Test layers model with Estimator and distribution strategies."""
    self._run_resnet('resnet56_1_gpu', num_gpus=1, dtype='fp32')

  def resnet56_fp16_1_gpu(self):
    """Test layers FP16 model with Estimator and distribution strategies."""
    self._run_resnet('resnet56_fp16_1_gpu', num_gpus=1, dtype='fp16')

  def resnet56_2_gpu(self):
    """Test layers model with Estimator and dist_strat. 2 GPUs."""
    self._run_resnet('resnet56_2_gpu', num_gpus=2, dtype='fp32')

  def resnet56_fp16_2_gpu(self):
    """Test layers FP16 model with Estimator and dist_strat. 2 GPUs."""
    self._run_resnet('resnet56_fp16_2_gpu', num_gpus=2, dtype='fp16')

  def unit_test(self):
    """A lightweight test that can finish quickly."""
    # Uses its own model directory: sharing 'resnet56_1_gpu' would mix this
    # ResNet-8 run's checkpoints with those of the ResNet-56 benchmark.
    self._run_resnet(
        'unit_test', num_gpus=1, dtype='fp32', resnet_size=8, train_epochs=1)

  def _run_resnet(self, model_dir_name, num_gpus, dtype, resnet_size=56,
                  train_epochs=182):
    """Resets the flags, applies one scenario's overrides and runs it.

    Args:
      model_dir_name: folder name, below the output directory, used as the
        model directory of this run.
      num_gpus: value for the `num_gpus` flag.
      dtype: value for the `dtype` flag, 'fp32' or 'fp16' in this file.
      resnet_size: value for the `resnet_size` flag.
      train_epochs: value for the `train_epochs` flag.
    """
    # _setup() must come first: it resets the flags to their saved defaults.
    self._setup()
    flags.FLAGS.num_gpus = num_gpus
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = train_epochs
    flags.FLAGS.model_dir = self._get_model_dir(model_dir_name)
    flags.FLAGS.resnet_size = resnet_size
    flags.FLAGS.dtype = dtype
    # _run_and_report_benchmark() looks for this hook to report throughput.
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    """Executes benchmark and reports result.

    Runs `cifar_main.run_cifar` with the current flag values, measures the wall
    time of that call and reports top-1/top-5 accuracy plus, when available,
    the mean examples-per-second recorded by the ExamplesPerSecondHook.
    """
    start_time_sec = time.time()
    stats = cifar_main.run_cifar(flags.FLAGS)
    wall_time_sec = time.time() - start_time_sec

    # First ExamplesPerSecondHook among the train hooks, or None.
    examples_per_sec_hook = next(
        (hook for hook in stats['train_hooks']
         if isinstance(hook, hooks.ExamplesPerSecondHook)),
        None)

    eval_results = stats['eval_results']
    metrics = [
        {'name': 'accuracy_top_1',
         'value': eval_results['accuracy'].item(),
         'min_value': MIN_TOP_1_ACCURACY,
         'max_value': MAX_TOP_1_ACCURACY},
        {'name': 'accuracy_top_5',
         'value': eval_results['accuracy_top_5'].item()},
    ]

    if examples_per_sec_hook is not None:
      exp_per_second_list = examples_per_sec_hook.current_examples_per_sec_list
      # ExamplesPerSecondHook skips the first 10 steps. If it recorded no
      # samples there is nothing to average, so the metric is left out instead
      # of dividing by zero.
      if exp_per_second_list:
        exp_per_sec = sum(exp_per_second_list) / len(exp_per_second_list)
        metrics.append({'name': 'exp_per_second',
                        'value': exp_per_sec})

    self.report_benchmark(
        iters=eval_results['global_step'],
        wall_time=wall_time_sec,
        metrics=metrics)

  def _get_model_dir(self, folder_name):
    """Returns the path of `folder_name` below the output directory.

    Args:
      folder_name: name of the folder to place below `self.output_dir`.

    Returns:
      The joined path. When `self.output_dir` is None the path is relative to
      the current working directory.
    """
    # os.path.join() raises TypeError on None, so use an empty prefix instead.
    return os.path.join(self.output_dir or '', folder_name)

  def _setup(self):
    """Defines the cifar flags once and resets them on every later call.

    The first call defines and parses the flags and saves their values in
    the class attribute `local_flags`; every later call restores that snapshot.
    """
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    if EstimatorCifar10BenchmarkTests.local_flags is None:
      cifar_main.define_cifar_flags()
      # Loads flags to get defaults to then override.
      flags.FLAGS(['foo'])
      EstimatorCifar10BenchmarkTests.local_flags = (
          flagsaver.save_flag_values())
    else:
      flagsaver.restore_flag_values(EstimatorCifar10BenchmarkTests.local_flags)