# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.logs import hooks
from official.utils.testing import mock_lib

tf.logging.set_verbosity(tf.logging.DEBUG)


class ExamplesPerSecondHookTest(tf.test.TestCase):
  """Tests for the ExamplesPerSecondHook.

  In this test, we explicitly run global_step tensor after train_op in order to
  grab the correct global step value. This is to correct for discrepancies in
  reported global step when running on GPUs. As in the after_run functions in
  ExamplesPerSecondHook, the global step from run_results
  (global_step = run_values.results) is not always correct and taken as the
  stale global_step (which may be 1 off the correct value). The exact
  global_step value should be from run_context
  (global_step = run_context.session.run(global_step_tensor)).
  """

  def setUp(self):
    """Mock out logging calls to verify if correct info is being monitored."""
    # Let the base test case run its own per-test setup before ours.
    super(ExamplesPerSecondHookTest, self).setUp()
    self._logger = mock_lib.MockBenchmarkLogger()

    # Each test gets a private graph holding a global step and an op that
    # increments it by one per run.
    self.graph = tf.Graph()
    with self.graph.as_default():
      tf.train.create_global_step()
      self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
      self.global_step = tf.train.get_global_step()

  def test_raise_in_both_secs_and_steps(self):
    """Passing both every_n_steps and every_n_secs raises ValueError."""
    with self.assertRaises(ValueError):
      hooks.ExamplesPerSecondHook(
          batch_size=256,
          every_n_steps=10,
          every_n_secs=20,
          metric_logger=self._logger)

  def test_raise_in_none_secs_and_steps(self):
    """Passing neither every_n_steps nor every_n_secs raises ValueError."""
    with self.assertRaises(ValueError):
      hooks.ExamplesPerSecondHook(
          batch_size=256,
          every_n_steps=None,
          every_n_secs=None,
          metric_logger=self._logger)

  def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
    """Checks step-based logging against the mock logger.

    Runs `every_n_steps` training steps expecting no metrics, then one more
    step after which metrics are expected only if the global step exceeds
    `warm_steps`, then a final step to check how the metric list grows.

    Args:
      every_n_steps: Value forwarded to the hook's `every_n_steps` argument.
      warm_steps: Value forwarded to the hook's `warm_steps` argument.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps,
        metric_logger=self._logger)

    with tf.train.MonitoredSession(
        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
      for _ in range(every_n_steps):
        # Explicitly run global_step after train_op to get the accurate
        # global_step value
        mon_sess.run(self.train_op)
        mon_sess.run(self.global_step)
        # Nothing should be in the list yet
        self.assertFalse(self._logger.logged_metric)

      mon_sess.run(self.train_op)
      global_step_val = mon_sess.run(self.global_step)

      if global_step_val > warm_steps:
        self._assert_metrics()
      else:
        # Nothing should be in the list yet
        self.assertFalse(self._logger.logged_metric)

      # Add additional run to verify proper reset when called multiple times.
      prev_log_len = len(self._logger.logged_metric)
      mon_sess.run(self.train_op)
      global_step_val = mon_sess.run(self.global_step)

      if every_n_steps == 1 and global_step_val > warm_steps:
        # Each time, we log two additional metrics. Did exactly 2 get added?
        self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
      else:
        # No change in the size of the metric list.
        self.assertEqual(len(self._logger.logged_metric), prev_log_len)

  def test_examples_per_sec_every_1_steps(self):
    with self.graph.as_default():
      self._validate_log_every_n_steps(1, 0)

  def test_examples_per_sec_every_5_steps(self):
    with self.graph.as_default():
      self._validate_log_every_n_steps(5, 0)

  def test_examples_per_sec_every_1_steps_with_warm_steps(self):
    with self.graph.as_default():
      self._validate_log_every_n_steps(1, 10)

  def test_examples_per_sec_every_5_steps_with_warm_steps(self):
    with self.graph.as_default():
      self._validate_log_every_n_steps(5, 10)

  def _validate_log_every_n_secs(self, every_n_secs):
    """Checks time-based logging against the mock logger.

    Runs one step expecting no metrics, sleeps for `every_n_secs` seconds of
    wall-clock time, then runs another step and expects both metrics.

    Args:
      every_n_secs: Value forwarded to the hook's `every_n_secs` argument; also
        the number of seconds this helper sleeps between the two steps.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=None,
        every_n_secs=every_n_secs,
        metric_logger=self._logger)

    with tf.train.MonitoredSession(
        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
      # Explicitly run global_step after train_op to get the accurate
      # global_step value
      mon_sess.run(self.train_op)
      mon_sess.run(self.global_step)
      # Nothing should be in the list yet
      self.assertFalse(self._logger.logged_metric)
      time.sleep(every_n_secs)

      mon_sess.run(self.train_op)
      mon_sess.run(self.global_step)
      self._assert_metrics()

  def test_examples_per_sec_every_1_secs(self):
    with self.graph.as_default():
      self._validate_log_every_n_secs(1)

  def test_examples_per_sec_every_5_secs(self):
    with self.graph.as_default():
      self._validate_log_every_n_secs(5)

  def _assert_metrics(self):
    """Asserts the two most recent logged metrics are the expected pair."""
    metrics = self._logger.logged_metric
    # Guard the negative indexing below so a missing metric shows up as an
    # assertion failure rather than an IndexError.
    self.assertGreaterEqual(len(metrics), 2)
    self.assertEqual(metrics[-2]["name"], "average_examples_per_sec")
    self.assertEqual(metrics[-1]["name"], "current_examples_per_sec")

# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
  tf.test.main()