"benchmark/kernels/vscode:/vscode.git/clone" did not exist on "6e2da5156176ed2d7fe2445b7c7316bc1650b20a"
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""

import abc
import contextlib
import functools
import inspect

import os
import numpy as np
import tensorflow as tf


def create_loop_fn(step_fn):
  """Builds a multi-step function driven by a plain Python `while` loop.

  Args:
    step_fn: A callable that takes `iterator` as its only argument.

  Returns:
    The `loop_fn` callable defined below.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Calls `step_fn` repeatedly, optionally folding its outputs into `state`.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`; it is forwarded unchanged to `step_fn`.
      num_steps: The number of steps to run. With `num_steps == -1` the loop
        keeps going until the iterator is exhausted.
      state: An optional initial state.
      reduce_fn: An optional callable of the form `reduce_fn(state, value)`,
        where `value` is the output of `step_fn` for one step. Its result
        becomes the new state.

    Returns:
      The state after the last completed step. This is the initial `state`
      when `reduce_fn` is `None`.
    """
    completed_steps = 0
    try:
      # The loop body lives inside `async_scope` so that an OutOfRangeError
      # can be handled properly under async remote eager execution.
      with tf.experimental.async_scope():
        while num_steps == -1 or completed_steps < num_steps:
          step_outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, step_outputs)
          completed_steps += 1
    except (StopIteration, tf.errors.OutOfRangeError):
      # Clear the pending async error, then fall through and hand back
      # whatever state was accumulated before the loop stopped.
      tf.experimental.async_clear_error()
    return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Builds a multi-step function driven by `tf.while_loop` on the host.

  Args:
    step_fn: A callable that takes `iterator` as its only argument.

  Returns:
    The `loop_fn` callable defined below.
  """

  def loop_fn(iterator, num_steps):
    """Calls `step_fn(iterator)` once per element of `tf.range(num_steps)`.

    Args:
      iterator: A nested structure of tf.data `Iterator` or
        `DistributedIterator`; it is forwarded unchanged to `step_fn`.
      num_steps: The number of steps in the loop. Must be a `tf.Tensor`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    # Reject Python numbers up front: as the error message says, a Python
    # object may cause retracing.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be an `tf.Tensor`. "
          "Python object may cause retracing.")

    for _ in tf.range(num_steps):
      step_fn(iterator)

  return loop_fn


def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A helper function to create a distributed dataset.

  Args:
    strategy: An instance of `tf.distribute.Strategy`. If `None`, the strategy
      returned by `tf.distribute.get_strategy()` is used.
    dataset_or_fn: An instance of `tf.data.Dataset`, or a function that returns
      a `tf.data.Dataset`. If it is a function, it may optionally declare a
      parameter named `input_context` (positional or keyword-only); the
      `tf.distribute.InputContext` supplied by the strategy is then passed
      under that name.
    *args: Positional arguments to be passed to `dataset_or_fn`.
    **kwargs: Keyword arguments to be passed to `dataset_or_fn`.

  Returns:
    A distributed Dataset.

  Raises:
    ValueError: If `dataset_or_fn` is neither a `tf.data.Dataset` nor callable.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`")

  # Inspect the signature once rather than on every `dataset_fn` call.
  argspec = inspect.getfullargspec(dataset_or_fn)
  accepts_input_context = (
      "input_context" in argspec.args or
      "input_context" in argspec.kwonlyargs)

  def dataset_fn(ctx):
    """Wrapped dataset function for creating a distributed dataset."""
    # Work on a per-call copy so the caller's keyword arguments captured by
    # this closure are never mutated between invocations.
    call_kwargs = dict(kwargs)
    # `ctx` is only forwarded when `dataset_or_fn` declares `input_context`;
    # otherwise it is not used.
    if accepts_input_context:
      call_kwargs["input_context"] = ctx
    return dataset_or_fn(*args, **call_kwargs)

  return strategy.experimental_distribute_datasets_from_function(dataset_fn)


class SummaryManager:
  """Manages writing summaries, keeping one writer per relative directory."""

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Constructs a summary manager object.

    Args:
      summary_dir: The directory to write summaries to. If `None`, the manager
        is disabled: `write_summaries` and `flush` do nothing and
        `summary_writer` returns no-op writers.
      summary_fn: A callable defined as `def summary_fn(name, tensor,
        step=None)`, which describes the summary operation.
      global_step: A `tf.Variable` instance for the global step. If `None`, the
        value of `tf.summary.experimental.get_step()` at construction time is
        used.
    """
    self._enabled = summary_dir is not None
    self._summary_dir = summary_dir
    self._summary_fn = summary_fn
    # Cache of created writers, keyed by path relative to `summary_dir`.
    self._summary_writers = {}

    if global_step is None:
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step

  def summary_writer(self, relative_path=""):
    """Returns the summary writer for `relative_path`, creating it if needed.

    Args:
      relative_path: The path in which to write summaries, relative to the
        summary directory. By default it is empty, which specifies the root
        directory.

    Returns:
      The cached writer for `relative_path`. On first use a file writer is
      created when the manager is enabled, and a no-op writer otherwise.
    """
    if relative_path in self._summary_writers:
      return self._summary_writers[relative_path]

    if self._enabled:
      writer = tf.summary.create_file_writer(
          os.path.join(self._summary_dir, relative_path))
    else:
      writer = tf.summary.create_noop_writer()
    self._summary_writers[relative_path] = writer
    return writer

  def flush(self):
    """Flushes the underlying summary writers (no-op when disabled)."""
    if self._enabled:
      tf.nest.map_structure(tf.summary.flush, self._summary_writers)

  def write_summaries(self, summary_dict):
    """Writes summaries for the given values.

    This recursively creates sub-directories for any nested dictionaries
    provided in `summary_dict`, yielding a hierarchy of directories which will
    then be reflected in the TensorBoard UI as different colored curves.

    E.g. users may evaluate on multiple datasets and return `summary_dict` as
    a nested dictionary:
    ```
    {
        "dataset": {
            "loss": loss,
            "accuracy": accuracy
        },
        "dataset2": {
            "loss": loss2,
            "accuracy": accuracy2
        },
    }
    ```
    This will create two sub-directories, "dataset" and "dataset2", inside the
    summary root directory, and write both the "loss" and "accuracy" summaries
    in each of them.

    Args:
      summary_dict: A dictionary of values. If any value in `summary_dict` is
        itself a dictionary, then the function will recursively create
        subdirectories with names given by the keys in the dictionary. The
        Tensor values are summarized using the summary writer instance specific
        to the parent relative path.
    """
    if not self._enabled:
      return
    self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    """Recursively writes `summary_dict` under `relative_path`.

    Args:
      summary_dict: A (possibly nested) dictionary of values to summarize.
      relative_path: Directory, relative to the summary root, whose writer
        receives the non-dictionary values of `summary_dict`.
    """
    for name, value in summary_dict.items():
      if isinstance(value, dict):
        # A nested dictionary maps to a sub-directory named after its key.
        self._write_summaries(
            value, relative_path=os.path.join(relative_path, name))
      else:
        with self.summary_writer(relative_path).as_default():
          self._summary_fn(name, value, step=self._global_step)


class Trigger(metaclass=abc.ABCMeta):
  """An abstract class representing a "trigger" for some event."""

  @abc.abstractmethod
  def __call__(self, value: float, force_trigger=False):
    """Maybe trigger the event based on the given value.

    Args:
      value: The value for triggering.
      force_trigger: Whether the trigger is forced to fire.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """

  @abc.abstractmethod
  def reset(self):
    """Reset states in the trigger."""


class IntervalTrigger(Trigger):
  """A trigger that fires each time `value` advances by a fixed interval."""

  def __init__(self, interval, start=0):
    """Constructs the IntervalTrigger.

    Args:
      interval: The triggering interval. A falsy or non-positive interval
        disables interval-based triggering (forced triggers still work).
      start: The value treated as the most recent trigger point initially.
    """
    self._interval = interval
    self._last_trigger_value = start

  def __call__(self, value, force_trigger=False):
    """Decides whether the event fires for `value`, recording it if so.

    Args:
      value: The value for triggering.
      force_trigger: If True, the trigger fires regardless of the interval,
        unless `value` equals the last trigger value.

    Returns:
      `True` if the trigger is triggered on the given `value`, and
      `False` otherwise.
    """
    fire = force_trigger and value != self._last_trigger_value
    if not fire:
      interval = self._interval
      fire = (interval and interval > 0 and
              value >= self._last_trigger_value + interval)

    if fire:
      self._last_trigger_value = value
      return True
    return False

  def reset(self):
    """Sets the last trigger value back to 0 (not to the `start` argument)."""
    self._last_trigger_value = 0


class EpochHelper:
  """A helper class to handle epochs in a customized training loop."""

  def __init__(self, epoch_steps, global_step):
    """Constructs the EpochHelper.

    Args:
      epoch_steps: An integer indicating how many steps are in an epoch.
      global_step: A `tf.Variable` instance indicating the current global
        step. Only its `numpy()` method is used.
    """
    self._epoch_steps = epoch_steps
    self._global_step = global_step
    self._current_epoch = None
    self._epoch_start_step = None
    self._in_epoch = False

  def epoch_begin(self):
    """Returns whether a new epoch should begin.

    When not already inside an epoch, this records the current global step as
    the epoch's first step, derives the epoch index from it, and returns
    `True`. While inside an epoch it returns `False` and changes nothing.
    """
    if self._in_epoch:
      return False
    current_step = self._global_step.numpy()
    self._epoch_start_step = current_step
    self._current_epoch = current_step // self._epoch_steps
    self._in_epoch = True
    return True

  def epoch_end(self):
    """Returns whether the current epoch should end.

    The epoch ends once the global step has moved into a later epoch than the
    one recorded by `epoch_begin`.

    Raises:
      ValueError: If called while not inside an epoch.
    """
    if not self._in_epoch:
      raise ValueError("`epoch_end` can only be called inside an epoch")
    current_step = self._global_step.numpy()
    epoch = current_step // self._epoch_steps

    if epoch > self._current_epoch:
      self._in_epoch = False
      return True
    return False

  @property
  def batch_index(self):
    """Index of the next batch within the current epoch."""
    return self._global_step.numpy() - self._epoch_start_step

  @property
  def current_epoch(self):
    """Epoch index recorded by the last `epoch_begin`, or `None` before it."""
    return self._current_epoch


@contextlib.contextmanager
def _soft_device_placement():
  """Enables soft device placement for the enclosed code.

  Intended to allow summaries on CPU. The setting that was in effect on entry
  is restored on exit, including when the enclosed code raises.
  """
  previous_setting = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    tf.config.set_soft_device_placement(previous_setting)


def train_function_with_summaries(*args, **kwargs):
  """Utility function to support TPU summaries via multiple `tf.function`s.

  This permits interleaving summaries inside TPU-compatible code, but without
  any performance impact on steps that do not write summaries.

  Usage is as a decorator, similar to `tf.function`, and any `tf.function`
  arguments will be passed through if supplied:

      @trainer.train_function_with_summaries
      def train(self, num_steps):
        ...

  The decorated function is assumed to be a loop method accepting a `num_steps`
  parameter, as for instance would be called within the `Controller`'s outer
  train loop. The implementation here assumes that `summary_frequency` is
  divisible by `steps_per_loop`. The decorated method should accept two
  arguments, `self` and `num_steps`.

  Two `tf.function` versions of `train_fn` are created: one inside a summary
  writer scope with soft device placement enabled (used on steps that require
  summary writing), and one with no summary writer present and soft device
  placement disabled (used on all other steps).

  Args:
    *args: Arguments to pass through to `tf.function`.
    **kwargs: Keyword arguments to pass through to `tf.function`.

  Returns:
    If the first argument is a callable, returns the decorated callable.
    Otherwise, returns a decorator.
  """

  def decorator(train_fn):
    # TODO(dhr): Validate the signature of train_fn?

    train_fn_with_summaries = tf.function(train_fn, *args, **kwargs)
    train_fn_without_summaries = tf.function(train_fn, *args, **kwargs)

    @functools.wraps(train_fn)
    def wrapper(self, num_steps):
      # Start from `None` so that a call which runs no step at all (summaries
      # not being recorded and `num_steps < 1`) returns `None` instead of
      # failing on an unbound local.
      output = None
      if tf.summary.should_record_summaries():
        # Run a single step with summaries enabled; the remaining steps run
        # in the summary-free function below.
        with _soft_device_placement():
          output = train_fn_with_summaries(self, tf.constant(1))
          num_steps -= 1
      if num_steps >= 1:
        with tf.summary.record_if(False):
          output = train_fn_without_summaries(self, num_steps)
      return output

    return wrapper

  # Bare `@train_function_with_summaries` usage: the first positional argument
  # is the function to decorate and the rest are `tf.function` arguments.
  if args and callable(args[0]):
    train_fn, args = args[0], args[1:]
    return decorator(train_fn)
  return decorator


def get_value(x) -> np.ndarray:
  """Returns the value of a variable/tensor.

  Args:
      x: The input; it may or may not be a TensorFlow tensor.

  Returns:
      `x.numpy()` when `tf.is_tensor(x)` is true (a Numpy array or number);
      otherwise `x` itself, unchanged.
  """
  return x.numpy() if tf.is_tensor(x) else x