# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metric_hook."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile
import time

import tensorflow as tf  # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order
from official.utils.logs import metric_hook
from official.utils.testing import mock_lib


class LoggingMetricHookTest(tf.test.TestCase):
  """Tests for LoggingMetricHook."""

  def setUp(self):
    super(LoggingMetricHookTest, self).setUp()

    # Scratch directory owned by this test; removed again in tearDown.
    self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    # Mock logger handed to every hook; the tests inspect what was logged
    # through `self._logger.logged_metric`.
    self._logger = mock_lib.MockBenchmarkLogger()

  def tearDown(self):
    # Clean up only what setUp created (rather than the whole
    # framework-managed get_temp_dir()), then let the base class tear down.
    tf.io.gfile.rmtree(self._log_dir)
    super(LoggingMetricHookTest, self).tearDown()

  def test_illegal_args(self):
    """Invalid constructor arguments raise ValueError with a useful message."""
    # The patterns omit the leading letter, presumably so they match either
    # capitalization of the message ("Invalid"/"invalid", "Exactly"/...).
    with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
      metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=0)
    with self.assertRaisesRegex(ValueError, "nvalid every_n_iter"):
      metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=-10)
    with self.assertRaisesRegex(ValueError, "xactly one of"):
      metric_hook.LoggingMetricHook(
          tensors=["t"], every_n_iter=5, every_n_secs=5)
    with self.assertRaisesRegex(ValueError, "xactly one of"):
      metric_hook.LoggingMetricHook(tensors=["t"])
    with self.assertRaisesRegex(ValueError, "metric_logger"):
      metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)

  def test_print_at_end_only(self):
    """With only `at_end=True`, nothing is logged until `end()` is called."""
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      t = tf.constant(42.0, name="foo")
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
      sess.run(tf.compat.v1.global_variables_initializer())

      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      self.assertEqual(len(self._logger.logged_metric), 1)
      metric = self._logger.logged_metric[0]
      self.assertRegex(metric["name"], "foo")
      self.assertEqual(metric["value"], 42.0)
      self.assertEqual(metric["unit"], None)
      # train_op is a constant and never advances the global step.
      self.assertEqual(metric["global_step"], 0)

  def test_global_step_not_found(self):
    """`begin()` fails when the graph has no global step."""
    with tf.Graph().as_default():
      t = tf.constant(42.0, name="foo")
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)

      with self.assertRaisesRegex(
          RuntimeError, "should be created to use LoggingMetricHook."):
        hook.begin()

  def test_log_tensors(self):
    """Tensor objects (not just names) are accepted and each one is logged."""
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      t1 = tf.constant(42.0, name="foo")
      t2 = tf.constant(43.0, name="bar")
      train_op = tf.constant(3)
      # Pass Tensor objects here; the other tests pass tensor names.
      hook = metric_hook.LoggingMetricHook(
          tensors=[t1, t2], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
      sess.run(tf.compat.v1.global_variables_initializer())

      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      self.assertEqual(len(self._logger.logged_metric), 2)
      metric1 = self._logger.logged_metric[0]
      self.assertRegex(str(metric1["name"]), "foo")
      self.assertEqual(metric1["value"], 42.0)
      self.assertEqual(metric1["unit"], None)
      self.assertEqual(metric1["global_step"], 0)

      metric2 = self._logger.logged_metric[1]
      self.assertRegex(str(metric2["name"]), "bar")
      self.assertEqual(metric2["value"], 43.0)
      self.assertEqual(metric2["unit"], None)
      self.assertEqual(metric2["global_step"], 0)

  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based logging on the 1st run and every 10th run after it.

    Args:
      sess: Session whose graph already contains a global step.
      at_end: Whether the hook should also log when `end()` is called.
    """
    t = tf.constant(42.0, name="foo")
    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The very first run logs.
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))
    for _ in range(3):
      self._logger.logged_metric = []
      # The next 9 runs stay silent...
      for _ in range(9):
        mon_sess.run(train_op)
        self.assertNotIn(t.name, str(self._logger.logged_metric))
      # ...and the 10th logs again.
      mon_sess.run(train_op)
      self.assertIn(t.name, str(self._logger.logged_metric))

    # Add additional run to verify proper reset when called multiple times.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))

    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertIn(t.name, str(self._logger.logged_metric))
    else:
      self.assertNotIn(t.name, str(self._logger.logged_metric))

  def test_print_every_n_steps(self):
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      self._validate_print_every_n_steps(sess, at_end=False)
      # Verify proper reset.
      self._validate_print_every_n_steps(sess, at_end=False)

  def test_print_every_n_steps_and_end(self):
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      self._validate_print_every_n_steps(sess, at_end=True)
      # Verify proper reset.
      self._validate_print_every_n_steps(sess, at_end=True)

  def _validate_print_every_n_secs(self, sess, at_end):
    """Checks time-based logging with a 1-second interval.

    Args:
      sess: Session whose graph already contains a global step.
      at_end: Whether the hook should also log when `end()` is called.
    """
    t = tf.constant(42.0, name="foo")
    train_op = tf.constant(3)

    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The first run logs.
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))

    # An immediate second run is inside the 1s window and stays silent.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))
    # Wait out the interval so the next run logs again.
    time.sleep(1.0)

    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))

    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertIn(t.name, str(self._logger.logged_metric))
    else:
      self.assertNotIn(t.name, str(self._logger.logged_metric))

  def test_print_every_n_secs(self):
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      self._validate_print_every_n_secs(sess, at_end=False)
      # Verify proper reset.
      self._validate_print_every_n_secs(sess, at_end=False)

  def test_print_every_n_secs_and_end(self):
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      self._validate_print_every_n_secs(sess, at_end=True)
      # Verify proper reset.
      self._validate_print_every_n_secs(sess, at_end=True)


if __name__ == "__main__":
  # Delegate to TensorFlow's test runner.
  tf.test.main()