hooks.py 4.28 KB
Newer Older
Yanhui Liang's avatar
Yanhui Liang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Hook that counts examples per second every N steps or seconds."""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class ExamplesPerSecondHook(tf.train.SessionRunHook):
  """Hook to print out examples per second.

  Total time is tracked and then divided by the total number of steps
  to get the average step time and then batch_size is used to determine
  the running average of examples per second. The examples per second for the
  most recent interval is also logged.
  """

  def __init__(self,
               batch_size,
               every_n_steps=None,
               every_n_secs=None,
               warm_steps=0):
    """Initializer for ExamplesPerSecondHook.

    Args:
      batch_size: Total batch size across all workers used to calculate
        examples/second from global time.
      every_n_steps: Log stats every n steps.
      every_n_secs: Log stats every n seconds. Exactly one of the
        `every_n_steps` or `every_n_secs` should be set.
      warm_steps: The number of steps to be skipped before logging and running
        average calculation. warm_steps steps refers to global steps across all
        workers, not on each worker.

    Raises:
      ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or
        both are set.
    """

    # True == True (both unset) and False == False (both set) are the two
    # invalid combinations; exactly one of the arguments must be given.
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError('exactly one of every_n_steps'
                       ' and every_n_secs should be provided.')

    self._timer = tf.train.SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    # Accumulated wall time and step count over all measured intervals.
    self._step_train_time = 0
    self._total_steps = 0
    self._batch_size = batch_size
    self._warm_steps = warm_steps
    # Resolved in begin(); declared here so the attribute always exists.
    self._global_step_tensor = None

  def begin(self):
    """Called once before using the session to check global step.

    Raises:
      RuntimeError: if no global step tensor is available.
    """
    self._global_step_tensor = tf.train.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          'Global step should be created to use ExamplesPerSecondHook.')

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().

    Args:
      run_context: A SessionRunContext object.

    Returns:
      A SessionRunArgs object requesting the current global step value.
    """
    return tf.train.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    Updates the accumulated time/step counters when the timer triggers and
    logs the current and average examples per second.

    Args:
      run_context: A SessionRunContext object.
      run_values: A SessionRunValues object; `results` holds the global step
        requested in `before_run`.
    """
    global_step = run_values.results

    # Nothing to do until the timer fires and the warm-up steps have passed.
    if not (self._timer.should_trigger_for_step(global_step) and
            global_step > self._warm_steps):
      return

    elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
        global_step)
    # No interval to measure yet (presumably the very first trigger, when the
    # timer has no previous trigger to compare against).
    if elapsed_time is None:
      return

    self._step_train_time += elapsed_time
    self._total_steps += elapsed_steps

    # A zero-length interval cannot yield a rate; skip logging instead of
    # raising ZeroDivisionError from inside the training loop.
    if elapsed_time <= 0:
      return

    # average examples per second is based on the total (accumulative)
    # training steps and training time so far
    average_examples_per_sec = self._batch_size * (
        self._total_steps / self._step_train_time)
    # current examples per second is based on the elapsed training steps
    # and training time per batch
    current_examples_per_sec = self._batch_size * (
        elapsed_steps / elapsed_time)
    # Current examples/sec followed by average examples/sec. The step count is
    # an integer, so it is formatted with %d; %g would switch to scientific
    # notation (e.g. 1e+06) for long runs.
    tf.logging.info('Batch [%d]:  current exp/sec = %g, average exp/sec = '
                    '%g', self._total_steps, current_examples_per_sec,
                    average_examples_per_sec)