# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random
import string

from absl import logging
import tensorflow.compat.v2 as tf

from official.utils.misc import tpu_lib


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are ['ring', 'nccl'].  Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "['nccl', 'hierarchical_copy'].  Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
      insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
      to use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl".  If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional.  Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.

  Returns:
    tf.distribute.DistributionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or `num_gpus` is negative; or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(
          "When {} GPUs are specified, distribution_strategy "
          "flag cannot be set to 'off'.".format(num_gpus))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      "Unrecognized Distribution Strategy: %r" % distribution_strategy)

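# Example usage (illustrative, not part of the original module): create a
# strategy for two local GPUs and build a model under its scope.
# `build_model` is a hypothetical user-supplied function.
#
#   strategy = get_distribution_strategy(
#       distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")
#   with get_strategy_scope(strategy):
#     model = build_model()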

# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers
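
# Example usage (illustrative): a `SyntheticDataset` caches a single element
# of a real dataset in variables, and its iterator returns that same batch on
# every step. `real_dataset` and `strategy` are placeholders here.
#
#   synthetic = SyntheticDataset(
#       real_dataset, split_by=strategy.num_replicas_in_sync)
#   batch = next(iter(synthetic))  # identical batch returned on each call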


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
  def make_dataset(self, dataset):
    logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  """Undo `_monkey_patch_dataset_method`, restoring the original methods."""
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset


def set_up_synthetic_data():
  """Monkey-patch supported distribution strategies to use synthetic data."""
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)


def undo_set_up_synthetic_data():
  """Undo the monkey-patching applied by `set_up_synthetic_data`."""
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
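
# Example usage (illustrative): benchmarks that want device-resident synthetic
# input call `set_up_synthetic_data()` before building strategies and input
# pipelines, and restore the original behavior afterwards. `run_training` is
# a hypothetical training routine.
#
#   set_up_synthetic_data()
#   try:
#     run_training()
#   finally:
#     undo_set_up_synthetic_data()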


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of this worker in the cluster. Must be set when more
      than one worker is specified.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  if tf_config:
    num_workers = (len(tf_config['cluster'].get('chief', [])) +
                   len(tf_config['cluster'].get('worker', [])))
  elif worker_hosts:
    workers = worker_hosts.split(',')
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError('Must specify task_index when number of workers > 1')
    task_index = 0 if num_workers == 1 else task_index
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': workers
        },
        'task': {'type': 'worker', 'index': task_index}
    })
  else:
    num_workers = 1
  return num_workers
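
# Example usage (illustrative): a two-worker cluster where this process is
# worker 1. The host:port values below are placeholders.
#
#   num_workers = configure_cluster(
#       worker_hosts="10.0.0.1:5000,10.0.0.2:5000", task_index=1)
#   strategy = get_distribution_strategy("multi_worker_mirrored")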


def get_strategy_scope(strategy):
  """Return `strategy.scope()`, or a no-op context if `strategy` is None."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """A no-op context manager used when no distribution strategy is set."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass