hooks_helper.py 5.9 KB
Newer Older
Yanhui Liang's avatar
Yanhui Liang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Hooks helper to return a list of TensorFlow hooks for training by name.

More hooks can be added to this set. To add a new hook, 1) add the new hook to
the registry in HOOKS, 2) add a corresponding function that parses out necessary
parameters.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Karmel Allison's avatar
Karmel Allison committed
27
import tensorflow as tf  # pylint: disable=g-bad-import-order
28
from absl import logging
Yanhui Liang's avatar
Yanhui Liang committed
29

30
31
32
from official.r1.utils.logs import hooks
from official.r1.utils.logs import logger
from official.r1.utils.logs import metric_hook
Yanhui Liang's avatar
Yanhui Liang committed
33
34
35
36
37
38

_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
                                        'cross_entropy',
                                        'train_accuracy'])


39
def get_train_hooks(name_list, use_tpu=False, **kwargs):
  """Factory for getting a list of TensorFlow hooks for training by name.

  Args:
    name_list: a list of strings naming the desired hooks. Each name is
      stripped of surrounding whitespace and lower-cased before it is looked up
      in HOOKS, so matching is case-insensitive. Allowed: LoggingTensorHook,
      ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook, StepCounterHook,
      which are defined as keys in HOOKS.
    use_tpu: Boolean of whether computation occurs on a TPU. This will disable
      hooks altogether.
    **kwargs: a dictionary of arguments to the hooks. The same keyword
      arguments are forwarded to every requested hook factory.

  Returns:
    list of instantiated hooks, ready to be used in a classifier.train call.
    The list is empty when name_list is empty or when use_tpu is set.

  Raises:
    ValueError: if an unrecognized name is passed.
  """
  if not name_list:
    return []

  if use_tpu:
    logging.warning(
        'hooks_helper received name_list `%s`, but a '
        'TPU is specified. No hooks will be used.', name_list)
    return []

  train_hooks = []
  for name in name_list:
    # HOOKS maps lower-cased hook names to factory functions (not hook
    # classes); each factory picks the arguments it needs out of kwargs.
    hook_factory = HOOKS.get(name.strip().lower())
    if hook_factory is None:
      raise ValueError('Unrecognized training hook requested: {}'.format(name))
    train_hooks.append(hook_factory(**kwargs))

  return train_hooks


77
def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  # pylint: disable=unused-argument
  """Function to get LoggingTensorHook.

  Args:
    every_n_iter: `int`, print the values of `tensors` once every N local
      steps taken on the current worker.
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. If not set, log _TENSORS_TO_LOG by default.
    **kwargs: additional keyword arguments. They are accepted so that this
      function can be called with the shared arguments passed to every hook
      factory, and are otherwise ignored.

  Returns:
    Returns a LoggingTensorHook with a standard set of tensors that will be
    printed to stdout.
  """
  if tensors_to_log is None:
    tensors_to_log = _TENSORS_TO_LOG

  return tf.estimator.LoggingTensorHook(
      tensors=tensors_to_log,
      every_n_iter=every_n_iter)


99
def get_profiler_hook(model_dir, save_steps=1000, **kwargs):  # pylint: disable=unused-argument
  """Function to get ProfilerHook.

  Args:
    model_dir: The directory to save the profile traces to. This argument has
      no default, so it must be supplied whenever this hook is requested.
    save_steps: `int`, save profile traces every N steps.
    **kwargs: additional keyword arguments. They are accepted so that this
      function can be called with the shared arguments passed to every hook
      factory, and are otherwise ignored.

  Returns:
    Returns a ProfilerHook that writes out timelines that can be loaded into
    profiling tools like chrome://tracing.
  """
  return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
Yanhui Liang's avatar
Yanhui Liang committed
112
113
114
115


def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
  """Function to get ExamplesPerSecondHook.

  Args:
    every_n_steps: `int`, print current and average examples per second every
      N steps.
    batch_size: `int`, total batch size used to calculate examples/second from
      global time.
    warm_steps: skip this number of steps before logging and running average.
    **kwargs: additional keyword arguments. They are accepted so that this
      function can be called with the shared arguments passed to every hook
      factory, and are otherwise ignored.

  Returns:
    Returns an ExamplesPerSecondHook configured with the given batch size,
    reporting interval and warm-up steps, which reports through the benchmark
    logger returned by logger.get_benchmark_logger().
  """
  return hooks.ExamplesPerSecondHook(
      batch_size=batch_size, every_n_steps=every_n_steps,
      warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger())
Yanhui Liang's avatar
Yanhui Liang committed
135
136


137
def get_logging_metric_hook(tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Function to get LoggingMetricHook.

  Args:
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. If not set, log _TENSORS_TO_LOG by default.
    every_n_secs: `int`, the frequency for logging the metric. Default to every
      10 mins.
    **kwargs: additional keyword arguments. They are accepted so that this
      function can be called with the shared arguments passed to every hook
      factory, and are otherwise ignored.

  Returns:
    Returns a LoggingMetricHook that saves tensor values in a JSON format.
  """
  if tensors_to_log is None:
    tensors_to_log = _TENSORS_TO_LOG

  return metric_hook.LoggingMetricHook(
      tensors=tensors_to_log,
      metric_logger=logger.get_benchmark_logger(),
      every_n_secs=every_n_secs)


160
161
162
163
164
165
def get_step_counter_hook(**kwargs):  # pylint: disable=unused-argument
  """Function to get StepCounterHook.

  Args:
    **kwargs: keyword arguments, accepted and ignored; the hook is created
      with its default settings.

  Returns:
    Returns a StepCounterHook constructed with no arguments.
  """
  return tf.estimator.StepCounterHook()


Yanhui Liang's avatar
Yanhui Liang committed
166
167
168
169
170
# A dictionary mapping each hook name to the function that builds that hook.
# Keys must be lower-case: get_train_hooks strips and lower-cases the requested
# name before looking it up here.
HOOKS = {
    'loggingtensorhook': get_logging_tensor_hook,
    'profilerhook': get_profiler_hook,
    'examplespersecondhook': get_examples_per_second_hook,
    'loggingmetrichook': get_logging_metric_hook,
    'stepcounterhook': get_step_counter_hook,
}