# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # pylint: disable=g-bad-import-order


class LoggingMetricHook(tf.estimator.LoggingTensorHook):
  """Hook to log benchmark metric information.

  This hook is very similar to `tf.estimator.LoggingTensorHook`, which logs the
  given tensors every N local steps, every N seconds, or at the end. The metric
  information is logged via the given `metric_logger` in JSON format, which can
  be consumed by a data analysis pipeline later.

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self, tensors, metric_logger=None,
               every_n_iter=None, every_n_secs=None, at_end=False):
    """Initializer for LoggingMetricHook.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
          or `iterable` of tensors/tensor names.
      metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
          the hook should use to write the log.
      every_n_iter: `int`, print the values of `tensors` once every N local
          steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
          seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
          provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
          end of the run.

    Raises:
      ValueError:
        1. `every_n_iter` is non-positive, or
        2. neither or both of `every_n_iter` and `every_n_secs` are provided, or
        3. `metric_logger` is not provided.
    """
    super(LoggingMetricHook, self).__init__(
        tensors=tensors,
        every_n_iter=every_n_iter,
        every_n_secs=every_n_secs,
        at_end=at_end)

    if metric_logger is None:
      raise ValueError("metric_logger should be provided.")
    self._logger = metric_logger

  def begin(self):
    super(LoggingMetricHook, self).begin()
    self._global_step_tensor = tf.compat.v1.train.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use LoggingMetricHook.")
    # Fetch the global step together with the logged tensors so that
    # _log_metric can report the step value alongside each metric.
    if self._global_step_tensor.name not in self._current_tensors:
      self._current_tensors[self._global_step_tensor.name] = (
          self._global_step_tensor)

  def after_run(self, unused_run_context, run_values):
    # `_should_trigger` is an internal state populated in `before_run`, which
    # uses `self._timer` to decide whether the hook should trigger on this step.
    if self._should_trigger:
      self._log_metric(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_metric(values)

  def _log_metric(self, tensor_values):
    self._timer.update_last_triggered_step(self._iter_count)
    global_step = tensor_values[self._global_step_tensor.name]
    # self._tag_order is populated during the init of LoggingTensorHook
    for tag in self._tag_order:
      self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
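

if __name__ == "__main__":
  # Illustrative usage sketch: wires the hook into a tiny Estimator training
  # loop. `_StdoutLogger`, `_model_fn` and `_input_fn` below are hypothetical
  # placeholders, not part of any library; in practice the hook is used with a
  # `BenchmarkLogger`, but any object exposing `log_metric` works.

  class _StdoutLogger(object):
    """Hypothetical logger exposing the `log_metric` interface the hook calls."""

    def log_metric(self, name, value, unit=None, global_step=None, extras=None):
      print("metric %s = %s (global_step=%s)" % (name, value, global_step))

  def _model_fn(features, labels, mode):
    """Trivial linear regression model, kept small for the example."""
    predictions = tf.compat.v1.layers.dense(features["x"], 1)
    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)
    # Name the loss tensor so the hook can look it up by name.
    loss = tf.identity(loss, name="train_loss")
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  def _input_fn():
    """Synthetic data: y = 2x."""
    x = tf.random.uniform([16, 1])
    return {"x": x}, 2.0 * x

  estimator = tf.estimator.Estimator(model_fn=_model_fn)
  metric_hook = LoggingMetricHook(
      tensors={"train_loss": "train_loss"},
      metric_logger=_StdoutLogger(),
      every_n_iter=10)
  estimator.train(_input_fn, steps=50, hooks=[metric_hook])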