Commit ec7fbf0d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Clean up] Move utils/logs to r1/utils.

PiperOrigin-RevId: 309079916
parent 87208da1
......@@ -19,11 +19,10 @@ import os
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.r1.utils.logs import logger
from official.r1.wide_deep import census_dataset
from official.r1.wide_deep import wide_deep_run_loop
from official.utils.flags import core as flags_core
def define_census_flags():
......
......@@ -23,12 +23,11 @@ import os
from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.recommendation import movielens
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.r1.utils.logs import logger
from official.r1.wide_deep import movielens_dataset
from official.r1.wide_deep import wide_deep_run_loop
from official.recommendation import movielens
from official.utils.flags import core as flags_core
def define_movie_flags():
......
......@@ -23,11 +23,11 @@ import shutil
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
......
......@@ -20,9 +20,7 @@ from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
......@@ -114,17 +112,13 @@ def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
help="Run the model op by op without building a model function.")
if hooks:
# Construct a pretty summary of hooks.
hook_list_str = (
u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
in hooks_helper.HOOKS]))
flags.DEFINE_list(
name="hooks", short_name="hk", default="LoggingTensorHook",
help=help_wrap(
u"A list of (case insensitive) strings to specify the names of "
u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
u"for details.".format(hook_list_str))
u"training hooks. Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See hooks_helper "
u"for details.")
)
key_flags.append("hooks")
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hooks_helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import hooks_helper
from official.utils.misc import keras_utils
class BaseTest(unittest.TestCase):
def setUp(self):
super(BaseTest, self).setUp()
if keras_utils.is_v2_0:
tf.compat.v1.disable_eager_execution()
def test_raise_in_non_list_names(self):
with self.assertRaises(ValueError):
hooks_helper.get_train_hooks(
'LoggingTensorHook, ProfilerHook', model_dir="", batch_size=256)
def test_raise_in_invalid_names(self):
invalid_names = ['StepCounterHook', 'StopAtStepHook']
with self.assertRaises(ValueError):
hooks_helper.get_train_hooks(invalid_names, model_dir="", batch_size=256)
def validate_train_hook_name(self,
test_hook_name,
expected_hook_name,
**kwargs):
returned_hook = hooks_helper.get_train_hooks(
[test_hook_name], model_dir="", **kwargs)
self.assertEqual(len(returned_hook), 1)
self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook)
self.assertEqual(returned_hook[0].__class__.__name__.lower(),
expected_hook_name)
def test_get_train_hooks_logging_tensor_hook(self):
self.validate_train_hook_name('LoggingTensorHook', 'loggingtensorhook')
def test_get_train_hooks_profiler_hook(self):
self.validate_train_hook_name('ProfilerHook', 'profilerhook')
def test_get_train_hooks_examples_per_second_hook(self):
self.validate_train_hook_name('ExamplesPerSecondHook',
'examplespersecondhook')
def test_get_logging_metric_hook(self):
test_hook_name = 'LoggingMetricHook'
self.validate_train_hook_name(test_hook_name, 'loggingmetrichook')
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment