loop_fns.py 4.9 KB
Newer Older
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for creating loop functions."""

from orbit.utils import tpu_summaries

import tensorflow as tf


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of the
      function (except that it must be compatible with any `reduce_fn` provided
      to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to support
    looping until the iterator is exhausted (when `num_steps == -1`) and to
    properly catch exceptions when running under async remote eager (as is the
    case in TPU training setups involving separate coordinator/worker machines).

    If the iterator is exhausted before `num_steps` steps have run, the loop
    stops early without raising and returns the state accumulated so far.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until exhausting the iterator.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`. If `reduce_fn` is not provided,
      `state` is returned unchanged (`None` when `state` is also omitted).
    """
    try:
      step = 0
      # To make sure the OutOfRangeError exception can be handled well under
      # async remote eager, we need to wrap the loop body in `async_scope`.
      with tf.experimental.async_scope():
        while num_steps == -1 or step < num_steps:
          outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, outputs)
          step += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      # End of input. Reset the async execution state so the error raised by
      # the exhausted iterator does not leak into subsequently scheduled ops,
      # then return whatever state was accumulated before exhaustion.
      tf.experimental.async_clear_error()
      return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.

    Returns:
      None. The outputs of `step_fn` are discarded.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    # Looping over `tf.range` (not Python `range`) is what allows AutoGraph to
    # convert this loop into a `tf.while_loop` inside a `tf.function`.
    for _ in tf.range(num_steps):
      step_fn(iterator)

  return loop_fn


class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs the wrapped loop, using `with_summaries` for at most one step.

    When `tf.summary.should_record_summaries()` is true, a single step is run
    through `with_summaries` and the remaining `num_steps - 1` steps (if any)
    through `without_summaries`. Otherwise all `num_steps` steps are run
    through `without_summaries`. Both variants are provided by the
    `OptionalSummariesFunction` base class (not shown here).

    Args:
      iterator: Passed through unchanged to the wrapped loop function.
      num_steps: The total number of steps to run. Presumably a `tf.Tensor`,
        as required by `create_tf_while_loop_fn` loop functions.

    Returns:
      The output of the last wrapped call that ran, or `None` if neither
      variant ran.
    """
    # Pre-initialize so that a call which runs neither variant (summaries not
    # being recorded and `num_steps < 1`) returns `None` rather than raising
    # `UnboundLocalError`.
    output = None
    if tf.summary.should_record_summaries():
      # NOTE(review): this branch does not check `num_steps`, so one step runs
      # even when `num_steps < 1` — confirm callers never pass 0.
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output