Unverified Commit 81a34cbe authored by Yanhui Liang's avatar Yanhui Liang Committed by GitHub
Browse files

Add logging utils (#3519)

* Adding logging utils

* restore utils

* delete old file

* update inputs and docstrings

* Update import and fix typos

* Fix formatting and comments

* Update tests
parent 4c7c8fa7
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hook that counts examples per second every N steps or seconds."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class ExamplesPerSecondHook(tf.train.SessionRunHook):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
most recent interval is also logged.
"""
def __init__(self,
batch_size,
every_n_steps=None,
every_n_secs=None,
warm_steps=0):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of the
`every_n_steps` or `every_n_secs` should be set.
warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all
workers, not on each worker
Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
both are set.
"""
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError('exactly one of every_n_steps'
' and every_n_secs should be provided.')
self._timer = tf.train.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._step_train_time = 0
self._total_steps = 0
self._batch_size = batch_size
self._warm_steps = warm_steps
def begin(self):
"""Called once before using the session to check global step."""
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
'Global step should be created to use StepCounterHook.')
def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object or None if never triggered.
"""
return tf.train.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values): # pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step = run_values.results
if self._timer.should_trigger_for_step(
global_step) and global_step > self._warm_steps:
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
self._step_train_time += elapsed_time
self._total_steps += elapsed_steps
# average examples per second is based on the total (accumulative)
# training steps and training time so far
average_examples_per_sec = self._batch_size * (
self._total_steps / self._step_train_time)
# current examples per second is based on the elapsed training steps
# and training time per batch
current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time)
# Current examples/sec followed by average examples/sec
tf.logging.info('Batch [%g]: current exp/sec = %g, average exp/sec = '
'%g', self._total_steps, current_examples_per_sec,
average_examples_per_sec)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks helper to return a list of TensorFlow hooks for training by name.
More hooks can be added to this set. To add a new hook, 1) add the new hook to
the registry in HOOKS, 2) add a corresponding function that parses out necessary
parameters.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.utils.logging import hooks
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'cross_entropy',
'train_accuracy'])
def get_train_hooks(name_list, **kwargs):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings to name desired hook classes. Allowed:
LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
as keys in HOOKS
kwargs: a dictionary of arguments to the hooks.
Returns:
list of instantiated hooks, ready to be used in a classifier.train call.
Raises:
ValueError: if an unrecognized name is passed.
"""
if not name_list:
return []
train_hooks = []
for name in name_list:
hook_name = HOOKS.get(name.strip().lower())
if hook_name is None:
raise ValueError('Unrecognized training hook requested: {}'.format(name))
else:
train_hooks.append(hook_name(**kwargs))
return train_hooks
def get_logging_tensor_hook(every_n_iter=100, **kwargs): # pylint: disable=unused-argument
"""Function to get LoggingTensorHook.
Args:
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
kwargs: a dictionary of arguments to LoggingTensorHook.
Returns:
Returns a LoggingTensorHook with a standard set of tensors that will be
printed to stdout.
"""
return tf.train.LoggingTensorHook(
tensors=_TENSORS_TO_LOG,
every_n_iter=every_n_iter)
def get_profiler_hook(save_steps=1000, **kwargs): # pylint: disable=unused-argument
"""Function to get ProfilerHook.
Args:
save_steps: `int`, print profile traces every N steps.
kwargs: a dictionary of arguments to ProfilerHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return tf.train.ProfilerHook(save_steps=save_steps)
def get_examples_per_second_hook(every_n_steps=100,
batch_size=128,
warm_steps=10,
**kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook.
Args:
every_n_steps: `int`, print current and average examples per second every
N steps.
batch_size: `int`, total batch size used to calculate examples/second from
global time.
warm_steps: skip this number of steps before logging and running average.
kwargs: a dictionary of arguments to ExamplesPerSecondHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return hooks.ExamplesPerSecondHook(every_n_steps=every_n_steps,
batch_size=batch_size,
warm_steps=warm_steps)
# A dictionary to map one hook name and its corresponding function
HOOKS = {
'loggingtensorhook': get_logging_tensor_hook,
'profilerhook': get_profiler_hook,
'examplespersecondhook': get_examples_per_second_hook,
}
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hooks_helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.utils.logging import hooks_helper
tf.logging.set_verbosity(tf.logging.ERROR)
class BaseTest(tf.test.TestCase):
def test_raise_in_non_list_names(self):
with self.assertRaises(ValueError):
hooks_helper.get_train_hooks(
'LoggingTensorHook, ProfilerHook', batch_size=256)
def test_raise_in_invalid_names(self):
invalid_names = ['StepCounterHook', 'StopAtStepHook']
with self.assertRaises(ValueError):
hooks_helper.get_train_hooks(invalid_names, batch_size=256)
def validate_train_hook_name(self,
test_hook_name,
expected_hook_name,
**kwargs):
returned_hook = hooks_helper.get_train_hooks([test_hook_name], **kwargs)
self.assertEqual(len(returned_hook), 1)
self.assertIsInstance(returned_hook[0], tf.train.SessionRunHook)
self.assertEqual(returned_hook[0].__class__.__name__.lower(),
expected_hook_name)
def test_get_train_hooks_LoggingTensorHook(self):
test_hook_name = 'LoggingTensorHook'
self.validate_train_hook_name(test_hook_name, 'loggingtensorhook')
def test_get_train_hooks_ProfilerHook(self):
test_hook_name = 'ProfilerHook'
self.validate_train_hook_name(test_hook_name, 'profilerhook')
def test_get_train_hooks_ExamplesPerSecondHook(self):
test_hook_name = 'ExamplesPerSecondHook'
self.validate_train_hook_name(test_hook_name, 'examplespersecondhook')
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hooks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import tensorflow as tf
from tensorflow.python.training import monitored_session
from official.utils.logging import hooks
tf.logging.set_verbosity(tf.logging.ERROR)
class ExamplesPerSecondHookTest(tf.test.TestCase):
def setUp(self):
"""Mock out logging calls to verify if correct info is being monitored."""
self._actual_log = tf.logging.info
self.logged_message = None
def mock_log(*args, **kwargs):
self.logged_message = args
self._actual_log(*args, **kwargs)
tf.logging.info = mock_log
self.graph = tf.Graph()
with self.graph.as_default():
self.global_step = tf.train.get_or_create_global_step()
self.train_op = tf.assign_add(self.global_step, 1)
def tearDown(self):
tf.logging.info = self._actual_log
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=10,
every_n_secs=20)
def test_raise_in_none_secs_and_steps(self):
with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=None,
every_n_secs=None)
def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=every_n_steps,
warm_steps=warm_steps)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())
self.logged_message = ''
for _ in range(every_n_steps):
mon_sess.run(self.train_op)
self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step)
# assertNotRegexpMatches is not supported by python 3.1 and later
if global_step_val > warm_steps:
self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
else:
self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
# Add additional run to verify proper reset when called multiple times.
self.logged_message = ''
mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step)
if every_n_steps == 1 and global_step_val > warm_steps:
self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
else:
self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
hook.end(sess)
def test_examples_per_sec_every_1_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 1, 0)
def test_examples_per_sec_every_5_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 5, 0)
def test_examples_per_sec_every_1_steps_with_warm_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 1, 10)
def test_examples_per_sec_every_5_steps_with_warm_steps(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_steps(sess, 5, 10)
def _validate_log_every_n_secs(self, sess, every_n_secs):
hook = hooks.ExamplesPerSecondHook(
batch_size=256,
every_n_steps=None,
every_n_secs=every_n_secs)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())
self.logged_message = ''
mon_sess.run(self.train_op)
self.assertEqual(str(self.logged_message).find('exp/sec'), -1)
time.sleep(every_n_secs)
self.logged_message = ''
mon_sess.run(self.train_op)
self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
hook.end(sess)
def test_examples_per_sec_every_1_secs(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_secs(sess, 1)
def test_examples_per_sec_every_5_secs(self):
with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_secs(sess, 5)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment