# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Hook that counts examples per second every N steps or seconds."""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.logs import logger


class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
  """Hook to print out examples per second.

  Total time is tracked and then divided by the total number of steps
  to get the average step time and then batch_size is used to determine
  the running average of examples per second. The examples per second for the
  most recent interval is also logged.
  """

  def __init__(self,
               batch_size,
               every_n_steps=None,
               every_n_secs=None,
               warm_steps=0,
               metric_logger=None):
    """Initializer for ExamplesPerSecondHook.

    Args:
      batch_size: Total batch size across all workers used to calculate
        examples/second from global time.
      every_n_steps: Log stats every n steps.
      every_n_secs: Log stats every n seconds. Exactly one of
        `every_n_steps` and `every_n_secs` should be set.
      warm_steps: The number of steps to be skipped before logging and running
        average calculation. warm_steps steps refers to global steps across all
        workers, not on each worker.
      metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
        hook should use to write the log. If None, BaseBenchmarkLogger will
        be used.

    Raises:
      ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or
        both are set.
    """
    # The timer needs exactly one schedule: step-based or time-based.
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError("exactly one of every_n_steps"
                       " and every_n_secs should be provided.")

    self._logger = metric_logger or logger.BaseBenchmarkLogger()

    self._timer = tf.estimator.SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    # Cumulative wall time and step count over all measured (post-warm-up)
    # intervals; their ratio drives the running average.
    self._step_train_time = 0
    self._total_steps = 0
    self._batch_size = batch_size
    self._warm_steps = warm_steps
    # Set in begin().
    self._global_step_tensor = None

  def begin(self):
    """Called once before using the session to check global step.

    Raises:
      RuntimeError: if `tf.compat.v1.train.get_global_step()` returns None.
    """
    self._global_step_tensor = tf.compat.v1.train.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use ExamplesPerSecondHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().

    Args:
      run_context: A SessionRunContext object.

    Returns:
      A SessionRunArgs object requesting the global step tensor, so that its
      value is available in `after_run` via `run_values.results`.
    """
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    On each timer trigger past `warm_steps`, accumulates the elapsed time and
    steps, then logs the running-average and most-recent examples per second.

    Args:
      run_context: A SessionRunContext object.
      run_values: A SessionRunValues object.
    """
    global_step = run_values.results

    # Nothing is measured between trigger points or during warm-up.
    if not (self._timer.should_trigger_for_step(global_step) and
            global_step > self._warm_steps):
      return

    elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
        global_step)
    # No elapsed interval is reported (None) when the timer had no previous
    # trigger to measure from; nothing to accumulate yet.
    if elapsed_time is None:
      return

    self._step_train_time += elapsed_time
    self._total_steps += elapsed_steps

    # Guard against a zero-length interval (e.g. coarse clock resolution),
    # which would otherwise raise ZeroDivisionError in the rates below.
    # Since both accumulators only grow, elapsed_time > 0 also implies
    # self._step_train_time > 0.
    if elapsed_time <= 0:
      return

    # Average examples per second is based on the total (accumulative)
    # training steps and training time so far.
    average_examples_per_sec = self._batch_size * (
        self._total_steps / self._step_train_time)
    # Current examples per second is based on the elapsed training steps
    # and training time of the most recent interval only.
    current_examples_per_sec = self._batch_size * (
        elapsed_steps / elapsed_time)

    self._logger.log_metric(
        "average_examples_per_sec", average_examples_per_sec,
        global_step=global_step)

    self._logger.log_metric(
        "current_examples_per_sec", current_examples_per_sec,
        global_step=global_step)