# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for creating loop functions."""

from orbit.utils import tpu_summaries

import tensorflow as tf


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of the
      function (except that it must be compatible with any `reduce_fn` provided
      to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs  = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to support
    looping until the iterator is exhausted (when `num_steps == -1`) and to
    properly catch exceptions when running under async remote eager (as is the
    case in TPU training setups involving separate coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until exhausting the iterator. Any other value that is not
        positive runs no steps.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`. If `reduce_fn` is `None`, the
      `state` argument is returned unchanged (`None` when it is not provided).
      If `step_fn` raises `StopIteration` or `tf.errors.OutOfRangeError` (the
      iterator is exhausted), the state accumulated up to that point is
      returned instead of propagating the error.
    """
    step = 0
    try:
      # To make sure the OutOfRangeError exception can be handled well under
      # async remote eager, we need to wrap the loop body in `async_scope`.
      with tf.experimental.async_scope():
        while num_steps == -1 or step < num_steps:
          outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, outputs)
          step += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      # Clear the error status recorded by async execution so that subsequent
      # ops are not affected by the exhaustion error we just handled.
      tf.experimental.async_clear_error()
      return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Builds a loop function whose loop AutoGraph can lower to `tf.while_loop`.

  Args:
    step_fn: Callable accepting a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Whatever it returns is discarded.

  Returns:
    A `loop_fn(iterator, num_steps)` callable. Its loop iterates over
    `tf.range(num_steps)`, so when it is traced inside a `tf.function`,
    AutoGraph converts it into a `tf.while_loop` construct. See `loop_fn`
    below for details.
  """
  # Error text used when a non-tensor step count is supplied.
  retrace_warning = (
      "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
      "cause unnecessary retracing when wrapped by `tf.function`.")

  def loop_fn(iterator, num_steps):
    """Calls `step_fn(iterator)` once per element of `tf.range(num_steps)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`, forwarded unchanged to `step_fn`.
      num_steps: A `tf.Tensor` holding the number of steps to run. Running
        until the iterator is exhausted is not supported.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    num_steps_is_tensor = isinstance(num_steps, tf.Tensor)
    if not num_steps_is_tensor:
      raise ValueError(retrace_warning)

    for unused_step in tf.range(num_steps):
      # An empty name scope resets the enclosing scope, so ops built in the
      # loop body are not given a "while/" name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn


def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop. Note that the nested
  structure of the `state` cannot be changed across iterations, although the
  shapes of the tensors inside it may change.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Its return value is passed as the `value`
      argument of the `reduce_fn` given to the returned loop function.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _relaxed_shape(t):
      """Returns `t`'s shape with every dimension unknown (None if no tensor)."""
      if not tf.is_tensor(t):
        # Non-tensor leaves get no shape invariant.
        return None
      rank = t.shape.rank
      if rank is None:
        # An unknown-rank shape is already fully relaxed; `[None] * None`
        # would raise a TypeError.
        return t.shape
      return tf.TensorShape([None] * rank)

    def _relaxed_shape_structure(s):
      """Maps every leaf of the nested structure `s` to its relaxed shape."""
      return tf.nest.pack_sequence_as(
          s, [_relaxed_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside `tf.while_loop`
      # don't get "while/" as name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful to aggregate outputs from each step
        # and concat to `state`. AutoGraph matches shape invariants to loop
        # variables by identity, so the invariant must be keyed on `state`
        # itself (with a matching nested structure of shapes) rather than on
        # its individual leaf tensors, which are not loop variables when
        # `state` is nested.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state


class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs the loop, recording summaries only on the first step.

    If `tf.summary.should_record_summaries()` is true, one step is run through
    `self.with_summaries` (with a step count of `tf.constant(1)`), and the
    remaining steps are run through `self.without_summaries`. Otherwise all
    `num_steps` steps are run through `self.without_summaries`. Note that
    when summaries are being recorded, the summary step runs even if
    `num_steps < 1`.

    Args:
      iterator: Forwarded to the underlying loop function(s).
      num_steps: The number of steps to run.

    Returns:
      The output of the last program invoked, or `None` if neither program
      ran (summaries not being recorded and `num_steps < 1`).
    """
    # Initialize so the method cannot hit an unbound local when neither
    # branch below runs.
    output = None
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output