Unverified Commit d41ed934 authored by Yanhui Liang's avatar Yanhui Liang Committed by GitHub
Browse files

Fix hooks_test for examples/second hook (#4411)

* Fix hooks_test

* Add more comments

* Fix lints
parent 04c81871
......@@ -22,17 +22,25 @@ from __future__ import print_function
import time
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.utils.logs import hooks
from official.utils.testing import mock_lib
tf.logging.set_verbosity(tf.logging.DEBUG)
class ExamplesPerSecondHookTest(tf.test.TestCase):
"""Tests for the ExamplesPerSecondHook."""
"""Tests for the ExamplesPerSecondHook.
In this test, we explicitly run global_step tensor after train_op in order to
grab the correct global step value. This is to correct for discrepancies in
reported global step when running on GPUs. As in the after_run functions in
ExamplesPerSecondHook, the global step from run_results
(global_step = run_values.results) is not always correct and taken as the
stale global_step (which may be 1 off the correct value). The exact
global_step value should be from run_context
(global_step = run_context.session.run(global_step_tensor)
"""
def setUp(self):
"""Mock out logging calls to verify if correct info is being monitored."""
......@@ -40,8 +48,9 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
self.graph = tf.Graph()
with self.graph.as_default():
self.global_step = tf.train.get_or_create_global_step()
self.train_op = tf.assign_add(self.global_step, 1)
tf.train.create_global_step()
self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
self.global_step = tf.train.get_global_step()
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
......@@ -59,86 +68,88 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
every_n_secs=None,
metric_logger=self._logger)
def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
def _validate_log_every_n_steps(self, every_n_steps, warm_steps):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=every_n_steps,
warm_steps=warm_steps,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
for _ in range(every_n_steps):
with tf.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
for _ in range(every_n_steps):
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
mon_sess.run(self.train_op)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
global_step_val = mon_sess.run(self.global_step)
mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step)
if global_step_val > warm_steps:
self._assert_metrics()
else:
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
if global_step_val > warm_steps:
self._assert_metrics()
else:
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
# Add additional run to verify proper reset when called multiple times.
prev_log_len = len(self._logger.logged_metric)
mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step)
if every_n_steps == 1 and global_step_val > warm_steps:
# Each time, we log two additional metrics. Did exactly 2 get added?
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
else:
# No change in the size of the metric list.
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
# Add additional run to verify proper reset when called multiple times.
prev_log_len = len(self._logger.logged_metric)
mon_sess.run(self.train_op)
global_step_val = mon_sess.run(self.global_step)
hook.end(sess)
if every_n_steps == 1 and global_step_val > warm_steps:
# Each time, we log two additional metrics. Did exactly 2 get added?
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
else:
# No change in the size of the metric list.
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
def test_examples_per_sec_every_1_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 1, 0)
with self.graph.as_default():
self._validate_log_every_n_steps(1, 0)
def test_examples_per_sec_every_5_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 5, 0)
with self.graph.as_default():
self._validate_log_every_n_steps(5, 0)
def test_examples_per_sec_every_1_steps_with_warm_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 1, 10)
with self.graph.as_default():
self._validate_log_every_n_steps(1, 10)
def test_examples_per_sec_every_5_steps_with_warm_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 5, 10)
with self.graph.as_default():
self._validate_log_every_n_steps(5, 10)
def _validate_log_every_n_secs(self, sess, every_n_secs):
def _validate_log_every_n_secs(self, every_n_secs):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=None,
every_n_secs=every_n_secs,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer())
mon_sess.run(self.train_op)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
time.sleep(every_n_secs)
mon_sess.run(self.train_op)
self._assert_metrics()
with tf.train.MonitoredSession(
tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
# Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
time.sleep(every_n_secs)
hook.end(sess)
mon_sess.run(self.train_op)
mon_sess.run(self.global_step)
self._assert_metrics()
def test_examples_per_sec_every_1_secs(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_secs(sess, 1)
with self.graph.as_default():
self._validate_log_every_n_secs(1)
def test_examples_per_sec_every_5_secs(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_secs(sess, 5)
with self.graph.as_default():
self._validate_log_every_n_secs(5)
def _assert_metrics(self):
metrics = self._logger.logged_metric
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment