loop_fns.py 7.83 KB
Newer Older
1
# Copyright 2021 The Orbit Authors. All Rights Reserved.
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Hongkun Yu's avatar
Hongkun Yu committed
14

Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
15
16
"""Utilities for creating loop functions."""

17
from absl import logging
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
18
19
20
21
22
23
from orbit.utils import tpu_summaries

import tensorflow as tf


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of the
      function (except that it must be compatible with any `reduce_fn` provided
      to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to support
    looping until the iterator is exhausted (when `num_steps == -1`) and to
    properly catch exceptions when running under async remote eager (as is the
    case in TPU training setups involving separate coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until exhausting the iterator.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    step = 0
    try:
      # To make sure the OutOfRangeError exception can be handled well under
      # async remote eager, we need to wrap the loop body in `async_scope`.
      with tf.experimental.async_scope():
        while num_steps == -1 or step < num_steps:
          outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, outputs)
          step += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.", step)
      # Under async eager, pending (possibly errored) ops must be cleared
      # before execution can continue.
      tf.experimental.async_clear_error()
      return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
    """
    # Require a tensor step count: a Python int baked into the trace would
    # force a retrace for every distinct value under `tf.function`.
    if not isinstance(num_steps, tf.Tensor):
      message = (
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")
      raise ValueError(message)

    # `tf.range` over a tensor bound is what lets AutoGraph lower this loop
    # to a `tf.while_loop` when traced.
    for _ in tf.range(num_steps):
      # Reset the name scope so ops created inside the generated
      # `tf.while_loop` don't pick up a "while/" name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allows a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None

      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shape of the input nested structure `s`."""
      # Fix: pack back into `s` (the parameter), not the enclosing `state` —
      # behaviorally identical at the single call site below, but the helper
      # previously ignored its own argument when reconstructing the structure.
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside `tf.while_loop`
      # don't get "while/" as name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful to aggregate outputs from each step
        # and concat to `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state


Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
195
196
197
198
199
200
201
202
203
204
205
206
207
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs up to `num_steps` steps, recording summaries on the first step only.

    Summary writes are expensive inside a TPU loop, so a single step is run
    with summaries enabled and the remaining steps run without them.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`, forwarded to the wrapped loop functions.
      num_steps: The total number of steps to run, as a `tf.Tensor`.

    Returns:
      The output of the last loop function invocation, or `None` if no steps
      were run (i.e. `num_steps < 1` with summary recording disabled).
    """
    # Initialize to avoid an UnboundLocalError when summaries are disabled
    # and `num_steps < 1` (no branch below would otherwise assign `output`).
    output = None
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output