"sgl-kernel/vscode:/vscode.git/clone" did not exist on "6f509d550350d98999a7e9f2a16e78dcd478be6b"
common.py 3.63 KB
Newer Older
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""

import inspect

import tensorflow as tf


def create_global_step() -> tf.Variable:
  """Builds a fresh `tf.Variable` intended to act as a global step counter.

  An explicitly managed step variable is useful for `AbstractTrainer`
  subclasses that apply several parameter updates within one `Controller`
  "step", or that alternate between optimizers from step to step.

  In those situations an `optimizer.iterations` property is generally not a
  usable step counter, because it tracks parameter updates rather than
  iterations of the `Controller`'s training loop. Such trainers should simply
  call `step.assign_add(1)` at the end of each step.

  Returns:
    A scalar, non-trainable `tf.Variable` with dtype `tf.int64`, initialized
    to zero and named "global_step". When synchronized across replicas in a
    distributed setting, only the first replica's value is kept.
  """
  global_step = tf.Variable(
      initial_value=0,
      trainable=False,
      name="global_step",
      dtype=tf.int64,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
  return global_step


def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A utility function to help create a `tf.distribute.DistributedDataset`.

  Args:
    strategy: An instance of `tf.distribute.Strategy`. If `None`, the strategy
      returned by `tf.distribute.get_strategy()` is used instead.
    dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
      returning a `tf.data.Dataset`. If it is a function, it may optionally have
      an argument named `input_context` (either a regular or a keyword-only
      argument) which will be passed a `tf.distribute.InputContext` instance.
    *args: Any positional arguments to pass through to `dataset_or_fn`. Ignored
      when `dataset_or_fn` is already a `tf.data.Dataset`.
    **kwargs: Any keyword arguments to pass through to `dataset_or_fn`. Ignored
      when `dataset_or_fn` is already a `tf.data.Dataset`.

  Returns:
    A distributed Dataset.

  Raises:
    ValueError: If `dataset_or_fn` is neither callable nor an instance of
      `tf.data.Dataset`.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`.")

  # The signature of `dataset_or_fn` does not change between invocations, so
  # inspect it once here rather than on every `dataset_fn` call. Keyword-only
  # arguments live in `kwonlyargs`, not `args`, so both must be checked.
  argspec = inspect.getfullargspec(dataset_or_fn)
  accepts_input_context = (
      "input_context" in argspec.args or
      "input_context" in argspec.kwonlyargs)

  def dataset_fn(input_context):
    """Wraps `dataset_or_fn` for strategy.distribute_datasets_from_function."""
    # If `dataset_or_fn` has an argument named `input_context`, pass through
    # the given `input_context`. Otherwise `input_context` will be ignored.
    # Work on a copy so the captured `kwargs` is never mutated across calls.
    call_kwargs = dict(kwargs)
    if accepts_input_context:
      call_kwargs["input_context"] = input_context
    return dataset_or_fn(*args, **call_kwargs)

  return strategy.distribute_datasets_from_function(dataset_fn)
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
86
87


Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
88
89
def get_value(x):
  """Converts a TensorFlow value to its NumPy value; passes others through.

  Args:
    x: The value to convert. May be a `tf.Tensor` or `tf.Variable`.

  Returns:
    The result of `x.numpy()` when `tf.is_tensor(x)` is true. Any other input
    is returned unchanged.
  """
  return x.numpy() if tf.is_tensor(x) else x