metric_hook.py 4.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metric."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from official.utils.logging import logger


class LoggingMetricHook(tf.train.LoggingTensorHook):
  """Session hook that records benchmark metrics through a BenchmarkLogger.

  Behaves like tf.train.LoggingTensorHook: it evaluates the given tensors
  every N local steps, every N seconds, and/or at the end of the run. Instead
  of printing the values, it forwards each one, stamped with the current global
  step, to a metric logger. That logger is either created for `log_dir` or is
  the supplied `metric_logger`. The records are meant to be written in JSON so
  a data analysis pipeline can consume them later.

  If `at_end` is True, `tensors` must not contain any tensor whose evaluation
  has a side effect (for example, consuming additional inputs). In `end`
  those tensors are evaluated once more in a separate `session.run` call.
  """

  def __init__(self, tensors, log_dir=None, metric_logger=None,
               every_n_iter=None, every_n_secs=None, at_end=False):
    """Creates the hook and resolves which metric logger it writes to.

    Args:
      tensors: `dict` mapping string tags to tensors or tensor names, or an
          `iterable` of tensors/tensor names.
      log_dir: `string`, directory that a newly created `BenchmarkLogger`
          should write to.
      metric_logger: an existing `BenchmarkLogger` instance to write through.
          Exactly one of `log_dir` and `metric_logger` must be given.
      every_n_iter: `int`, log the values of `tensors` once every N local
          steps taken on the current worker.
      every_n_secs: `int` or `float`, log the values of `tensors` once every
          N seconds. Exactly one of `every_n_iter` and `every_n_secs` should
          be given.
      at_end: `bool`, whether to also log the values of `tensors` when the
          session ends.

    Raises:
      ValueError: if `every_n_iter` is non-positive, if not exactly one of
          `every_n_iter` and `every_n_secs` is given, or if not exactly one of
          `log_dir` and `metric_logger` is given.
    """
    # The trigger arguments (every_n_iter / every_n_secs / at_end) are not
    # checked here; they are handed straight to the base class.
    super(LoggingMetricHook, self).__init__(
        tensors=tensors,
        every_n_iter=every_n_iter,
        every_n_secs=every_n_secs,
        at_end=at_end)

    given = [arg for arg in (log_dir, metric_logger) if arg is not None]
    if len(given) != 1:
      raise ValueError(
          "exactly one of log_dir and metric_logger should be provided.")

    # A directory gets a fresh logger; otherwise reuse the caller's logger.
    self._logger = (metric_logger if log_dir is None
                    else logger.BenchmarkLogger(log_dir))

  def begin(self):
    """Prepares the tensors to fetch and adds the global step if it is missing.

    Raises:
      RuntimeError: if `tf.train.get_global_step()` returns None.
    """
    super(LoggingMetricHook, self).begin()
    step_tensor = tf.train.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use LoggingMetricHook.")
    # Fetch the global step alongside the user tensors, keyed by its tensor
    # name, so that every logged metric can be stamped with it.
    self._current_tensors.setdefault(step_tensor.name, step_tensor)

  def after_run(self, unused_run_context, run_values):
    """Logs the fetched values if this step was selected, then counts the step."""
    # `_should_trigger` is internal state that the base class sets in
    # `before_run`, using `self._timer`. When it is True, this method expects
    # `run_values.results` to map each tag / tensor name to its fetched value.
    if self._should_trigger:
      self._log_metric(run_values.results)
    self._iter_count += 1

  def end(self, session):
    """Evaluates and logs the tensors one final time if `_log_at_end` is set."""
    if not self._log_at_end:
      return
    self._log_metric(session.run(self._current_tensors))

  def _log_metric(self, tensor_values):
    """Sends each tagged value to the metric logger, stamped with the global step.

    Args:
      tensor_values: mapping from each tag, and from the global step tensor's
          name, to the value evaluated for it.
    """
    self._timer.update_last_triggered_step(self._iter_count)
    step_value = tensor_values[self._global_step_tensor.name]
    # `_tag_order` is filled by the LoggingTensorHook constructor, before
    # `begin` runs. It therefore does not include the global step entry that
    # `begin` adds.
    for tag in self._tag_order:
      self._logger.log_metric(
          tag, tensor_values[tag], global_step=step_value)