Unverified Commit d626b908 authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Fix/log ex per sec (#4360)

* Using BenchmarkLogger

* Using BenchmarkLogger

* Fixing tests

* Linting fixes.

* Adding comments

* Moving mock logger

* Moving mock logger

* Glinting

* Responding to CR

* Reverting assertEmpty
parent 023fc2b2
...@@ -20,7 +20,9 @@ from __future__ import absolute_import ...@@ -20,7 +20,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logs import logger
class ExamplesPerSecondHook(tf.train.SessionRunHook): class ExamplesPerSecondHook(tf.train.SessionRunHook):
...@@ -36,7 +38,8 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -36,7 +38,8 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
batch_size, batch_size,
every_n_steps=None, every_n_steps=None,
every_n_secs=None, every_n_secs=None,
warm_steps=0): warm_steps=0,
metric_logger=None):
"""Initializer for ExamplesPerSecondHook. """Initializer for ExamplesPerSecondHook.
Args: Args:
...@@ -48,6 +51,9 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -48,6 +51,9 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
warm_steps: The number of steps to be skipped before logging and running warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all average calculation. warm_steps steps refers to global steps across all
workers, not on each worker workers, not on each worker
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises: Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
...@@ -55,8 +61,10 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -55,8 +61,10 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
""" """
if (every_n_steps is None) == (every_n_secs is None): if (every_n_steps is None) == (every_n_secs is None):
raise ValueError('exactly one of every_n_steps' raise ValueError("exactly one of every_n_steps"
' and every_n_secs should be provided.') " and every_n_secs should be provided.")
self._logger = metric_logger or logger.BaseBenchmarkLogger()
self._timer = tf.train.SecondOrStepTimer( self._timer = tf.train.SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs) every_steps=every_n_steps, every_secs=every_n_secs)
...@@ -71,7 +79,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -71,7 +79,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
self._global_step_tensor = tf.train.get_global_step() self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None: if self._global_step_tensor is None:
raise RuntimeError( raise RuntimeError(
'Global step should be created to use StepCounterHook.') "Global step should be created to use StepCounterHook.")
def before_run(self, run_context): # pylint: disable=unused-argument def before_run(self, run_context): # pylint: disable=unused-argument
"""Called before each call to run(). """Called before each call to run().
...@@ -109,7 +117,11 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook): ...@@ -109,7 +117,11 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
# and training time per batch # and training time per batch
current_examples_per_sec = self._batch_size * ( current_examples_per_sec = self._batch_size * (
elapsed_steps / elapsed_time) elapsed_steps / elapsed_time)
# Current examples/sec followed by average examples/sec
tf.logging.info('Batch [%g]: current exp/sec = %g, average exp/sec = ' self._logger.log_metric(
'%g', self._total_steps, current_examples_per_sec, "average_examples_per_sec", average_examples_per_sec,
average_examples_per_sec) global_step=global_step)
self._logger.log_metric(
"current_examples_per_sec", current_examples_per_sec,
global_step=global_step)
...@@ -119,9 +119,9 @@ def get_examples_per_second_hook(every_n_steps=100, ...@@ -119,9 +119,9 @@ def get_examples_per_second_hook(every_n_steps=100,
Returns a ProfilerHook that writes out timelines that can be loaded into Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing. profiling tools like chrome://tracing.
""" """
return hooks.ExamplesPerSecondHook(every_n_steps=every_n_steps, return hooks.ExamplesPerSecondHook(
batch_size=batch_size, batch_size=batch_size, every_n_steps=every_n_steps,
warm_steps=warm_steps) warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())
def get_logging_metric_hook(tensors_to_log=None, def get_logging_metric_hook(tensors_to_log=None,
......
...@@ -25,9 +25,10 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -25,9 +25,10 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.utils.logs import hooks from official.utils.logs import hooks
from official.utils.testing import mock_lib
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.DEBUG)
class ExamplesPerSecondHookTest(tf.test.TestCase): class ExamplesPerSecondHookTest(tf.test.TestCase):
...@@ -35,67 +36,63 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -35,67 +36,63 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
def setUp(self): def setUp(self):
"""Mock out logging calls to verify if correct info is being monitored.""" """Mock out logging calls to verify if correct info is being monitored."""
self._actual_log = tf.logging.info self._logger = mock_lib.MockBenchmarkLogger()
self.logged_message = None
def mock_log(*args, **kwargs):
self.logged_message = args
self._actual_log(*args, **kwargs)
tf.logging.info = mock_log
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
self.global_step = tf.train.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
self.train_op = tf.assign_add(self.global_step, 1) self.train_op = tf.assign_add(self.global_step, 1)
def tearDown(self):
tf.logging.info = self._actual_log
def test_raise_in_both_secs_and_steps(self): def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook( hooks.ExamplesPerSecondHook(
batch_size=256, batch_size=256,
every_n_steps=10, every_n_steps=10,
every_n_secs=20) every_n_secs=20,
metric_logger=self._logger)
def test_raise_in_none_secs_and_steps(self): def test_raise_in_none_secs_and_steps(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
hooks.ExamplesPerSecondHook( hooks.ExamplesPerSecondHook(
batch_size=256, batch_size=256,
every_n_steps=None, every_n_steps=None,
every_n_secs=None) every_n_secs=None,
metric_logger=self._logger)
def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps): def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
hook = hooks.ExamplesPerSecondHook( hook = hooks.ExamplesPerSecondHook(
batch_size=256, batch_size=256,
every_n_steps=every_n_steps, every_n_steps=every_n_steps,
warm_steps=warm_steps) warm_steps=warm_steps,
metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.logged_message = ''
for _ in range(every_n_steps): for _ in range(every_n_steps):
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
self.assertEqual(str(self.logged_message).find('exp/sec'), -1) # Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step) global_step_val = sess.run(self.global_step)
# assertNotRegexpMatches is not supported by python 3.1 and later
if global_step_val > warm_steps: if global_step_val > warm_steps:
self.assertRegexpMatches(str(self.logged_message), 'exp/sec') self._assert_metrics()
else: else:
self.assertEqual(str(self.logged_message).find('exp/sec'), -1) # Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
# Add additional run to verify proper reset when called multiple times. # Add additional run to verify proper reset when called multiple times.
self.logged_message = '' prev_log_len = len(self._logger.logged_metric)
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
global_step_val = sess.run(self.global_step) global_step_val = sess.run(self.global_step)
if every_n_steps == 1 and global_step_val > warm_steps: if every_n_steps == 1 and global_step_val > warm_steps:
self.assertRegexpMatches(str(self.logged_message), 'exp/sec') # Each time, we log two additional metrics. Did exactly 2 get added?
self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2)
else: else:
self.assertEqual(str(self.logged_message).find('exp/sec'), -1) # No change in the size of the metric list.
self.assertEqual(len(self._logger.logged_metric), prev_log_len)
hook.end(sess) hook.end(sess)
...@@ -119,19 +116,19 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -119,19 +116,19 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
hook = hooks.ExamplesPerSecondHook( hook = hooks.ExamplesPerSecondHook(
batch_size=256, batch_size=256,
every_n_steps=None, every_n_steps=None,
every_n_secs=every_n_secs) every_n_secs=every_n_secs,
metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.logged_message = ''
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
self.assertEqual(str(self.logged_message).find('exp/sec'), -1) # Nothing should be in the list yet
self.assertFalse(self._logger.logged_metric)
time.sleep(every_n_secs) time.sleep(every_n_secs)
self.logged_message = ''
mon_sess.run(self.train_op) mon_sess.run(self.train_op)
self.assertRegexpMatches(str(self.logged_message), 'exp/sec') self._assert_metrics()
hook.end(sess) hook.end(sess)
...@@ -143,6 +140,11 @@ class ExamplesPerSecondHookTest(tf.test.TestCase): ...@@ -143,6 +140,11 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
with self.graph.as_default(), tf.Session() as sess: with self.graph.as_default(), tf.Session() as sess:
self._validate_log_every_n_secs(sess, 5) self._validate_log_every_n_secs(sess, 5)
def _assert_metrics(self):
metrics = self._logger.logged_metric
self.assertEqual(metrics[-2]["name"], "average_examples_per_sec")
self.assertEqual(metrics[-1]["name"], "current_examples_per_sec")
if __name__ == '__main__': if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -47,20 +47,20 @@ _logger_lock = threading.Lock() ...@@ -47,20 +47,20 @@ _logger_lock = threading.Lock()
def config_benchmark_logger(flag_obj=None): def config_benchmark_logger(flag_obj=None):
"""Config the global benchmark logger""" """Config the global benchmark logger."""
_logger_lock.acquire() _logger_lock.acquire()
try: try:
global _benchmark_logger global _benchmark_logger
if not flag_obj: if not flag_obj:
flag_obj = FLAGS flag_obj = FLAGS
if (not hasattr(flag_obj, 'benchmark_logger_type') or if (not hasattr(flag_obj, "benchmark_logger_type") or
flag_obj.benchmark_logger_type == 'BaseBenchmarkLogger'): flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"):
_benchmark_logger = BaseBenchmarkLogger() _benchmark_logger = BaseBenchmarkLogger()
elif flag_obj.benchmark_logger_type == 'BenchmarkFileLogger': elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger":
_benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir) _benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir)
elif flag_obj.benchmark_logger_type == 'BenchmarkBigQueryLogger': elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger":
from official.benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top from official.benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top
bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project) bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project)
_benchmark_logger = BenchmarkBigQueryLogger( _benchmark_logger = BenchmarkBigQueryLogger(
bigquery_uploader=bq_uploader, bigquery_uploader=bq_uploader,
...@@ -69,8 +69,8 @@ def config_benchmark_logger(flag_obj=None): ...@@ -69,8 +69,8 @@ def config_benchmark_logger(flag_obj=None):
bigquery_metric_table=flag_obj.bigquery_metric_table, bigquery_metric_table=flag_obj.bigquery_metric_table,
run_id=str(uuid.uuid4())) run_id=str(uuid.uuid4()))
else: else:
raise ValueError('Unrecognized benchmark_logger_type: %s', raise ValueError("Unrecognized benchmark_logger_type: %s"
flag_obj.benchmark_logger_type) % flag_obj.benchmark_logger_type)
finally: finally:
_logger_lock.release() _logger_lock.release()
...@@ -247,6 +247,7 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger): ...@@ -247,6 +247,7 @@ class BenchmarkBigQueryLogger(BaseBenchmarkLogger):
self._run_id, self._run_id,
run_info)) run_info))
def _gather_run_info(model_name, dataset_name, run_params): def _gather_run_info(model_name, dataset_name, run_params):
"""Collect the benchmark run information for the local environment.""" """Collect the benchmark run information for the local environment."""
run_info = { run_info = {
...@@ -303,6 +304,7 @@ def _collect_run_params(run_info, run_params): ...@@ -303,6 +304,7 @@ def _collect_run_params(run_info, run_params):
run_info["run_parameters"] = [ run_info["run_parameters"] = [
process_param(k, v) for k, v in sorted(run_params.items())] process_param(k, v) for k, v in sorted(run_params.items())]
def _collect_tensorflow_environment_variables(run_info): def _collect_tensorflow_environment_variables(run_info):
run_info["tensorflow_environment_variables"] = [ run_info["tensorflow_environment_variables"] = [
{"name": k, "value": v} {"name": k, "value": v}
......
...@@ -21,10 +21,11 @@ from __future__ import print_function ...@@ -21,10 +21,11 @@ from __future__ import print_function
import tempfile import tempfile
import time import time
import tensorflow as tf import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.utils.logs import metric_hook # pylint: disable=g-bad-import-order from official.utils.logs import metric_hook
from official.utils.testing import mock_lib
class LoggingMetricHookTest(tf.test.TestCase): class LoggingMetricHookTest(tf.test.TestCase):
...@@ -33,49 +34,35 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -33,49 +34,35 @@ class LoggingMetricHookTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(LoggingMetricHookTest, self).setUp() super(LoggingMetricHookTest, self).setUp()
class MockMetricLogger(object):
def __init__(self):
self.logged_metric = []
def log_metric(self, name, value, unit=None, global_step=None,
extras=None):
self.logged_metric.append({
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"extras": extras})
self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
self._logger = MockMetricLogger() self._logger = mock_lib.MockBenchmarkLogger()
def tearDown(self): def tearDown(self):
super(LoggingMetricHookTest, self).tearDown() super(LoggingMetricHookTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.gfile.DeleteRecursively(self.get_temp_dir())
def test_illegal_args(self): def test_illegal_args(self):
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=0) metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=0)
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=-10) metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=-10)
with self.assertRaisesRegexp(ValueError, 'xactly one of'): with self.assertRaisesRegexp(ValueError, "xactly one of"):
metric_hook.LoggingMetricHook( metric_hook.LoggingMetricHook(
tensors=['t'], every_n_iter=5, every_n_secs=5) tensors=["t"], every_n_iter=5, every_n_secs=5)
with self.assertRaisesRegexp(ValueError, 'xactly one of'): with self.assertRaisesRegexp(ValueError, "xactly one of"):
metric_hook.LoggingMetricHook(tensors=['t']) metric_hook.LoggingMetricHook(tensors=["t"])
with self.assertRaisesRegexp(ValueError, 'metric_logger'): with self.assertRaisesRegexp(ValueError, "metric_logger"):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=5) metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
def test_print_at_end_only(self): def test_print_at_end_only(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
t = tf.constant(42.0, name='foo') t = tf.constant(42.0, name="foo")
train_op = tf.constant(3) train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger) tensors=[t.name], at_end=True, metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
for _ in range(3): for _ in range(3):
...@@ -91,25 +78,25 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -91,25 +78,25 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(metric["global_step"], 0) self.assertEqual(metric["global_step"], 0)
def test_global_step_not_found(self): def test_global_step_not_found(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default():
t = tf.constant(42.0, name='foo') t = tf.constant(42.0, name="foo")
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger) tensors=[t.name], at_end=True, metric_logger=self._logger)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
RuntimeError, 'should be created to use LoggingMetricHook.'): RuntimeError, "should be created to use LoggingMetricHook."):
hook.begin() hook.begin()
def test_log_tensors(self): def test_log_tensors(self):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
t1 = tf.constant(42.0, name='foo') t1 = tf.constant(42.0, name="foo")
t2 = tf.constant(43.0, name='bar') t2 = tf.constant(43.0, name="bar")
train_op = tf.constant(3) train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t1, t2], at_end=True, metric_logger=self._logger) tensors=[t1, t2], at_end=True, metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
for _ in range(3): for _ in range(3):
...@@ -131,14 +118,14 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -131,14 +118,14 @@ class LoggingMetricHookTest(tf.test.TestCase):
self.assertEqual(metric2["global_step"], 0) self.assertEqual(metric2["global_step"], 0)
def _validate_print_every_n_steps(self, sess, at_end): def _validate_print_every_n_steps(self, sess, at_end):
t = tf.constant(42.0, name='foo') t = tf.constant(42.0, name="foo")
train_op = tf.constant(3) train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_iter=10, at_end=at_end, tensors=[t.name], every_n_iter=10, at_end=at_end,
metric_logger=self._logger) metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
mon_sess.run(train_op) mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name) self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
...@@ -180,14 +167,14 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -180,14 +167,14 @@ class LoggingMetricHookTest(tf.test.TestCase):
self._validate_print_every_n_steps(sess, at_end=True) self._validate_print_every_n_steps(sess, at_end=True)
def _validate_print_every_n_secs(self, sess, at_end): def _validate_print_every_n_secs(self, sess, at_end):
t = tf.constant(42.0, name='foo') t = tf.constant(42.0, name="foo")
train_op = tf.constant(3) train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook( hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_secs=1.0, at_end=at_end, tensors=[t.name], every_n_secs=1.0, at_end=at_end,
metric_logger=self._logger) metric_logger=self._logger)
hook.begin() hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess = monitored_session._HookedSession(sess, [hook]) # pylint: disable=protected-access
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
mon_sess.run(train_op) mon_sess.run(train_op)
...@@ -226,5 +213,5 @@ class LoggingMetricHookTest(tf.test.TestCase): ...@@ -226,5 +213,5 @@ class LoggingMetricHookTest(tf.test.TestCase):
self._validate_print_every_n_secs(sess, at_end=True) self._validate_print_every_n_secs(sess, at_end=True)
if __name__ == '__main__': if __name__ == "__main__":
tf.test.main() tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mock objects and related functions for testing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MockBenchmarkLogger(object):
"""This is a mock logger that can be used in dependent tests."""
def __init__(self):
self.logged_metric = []
def log_metric(self, name, value, unit=None, global_step=None,
extras=None):
self.logged_metric.append({
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"extras": extras})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment