# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random
import string

import tensorflow as tf

from official.utils.misc import tpu_lib


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are ['ring', 'nccl'].  Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "['nccl', 'hierarchical_copy'].  Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)
82

83

84
85
def get_distribution_strategy(distribution_strategy="default",
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
      insensitive. 'off' means not to use Distribution Strategy; 'default'
      means to choose from `MirroredStrategy`, `MultiWorkerMirroredStrategy`,
      or `OneDeviceStrategy` according to the number of GPUs and number of
      workers. 'tpu' means to use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    num_workers: Number of workers to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.

  Returns:
    tf.distribute.DistributionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or `num_gpus` is negative; or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` cannot be negative.")

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(
          "When {} GPUs and  {} workers are specified, distribution_strategy "
          "flag cannot be set to 'off'.".format(num_gpus, num_workers))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if (distribution_strategy == "one_device" or
      (distribution_strategy == "default" and num_gpus <= 1)):
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    else:
      if num_gpus > 1:
        raise ValueError("`OneDeviceStrategy` can not be used for more than "
                         "one device.")
      return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy in ("mirrored", "default"):
    if num_gpus == 0:
      assert distribution_strategy == "mirrored"
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      "Unrecognized Distribution Strategy: %r" % distribution_strategy)


def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. When used with Estimator, we need to compute the per-GPU batch size
  ourselves.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(batch_size / num_gpus)
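

# Quick worked example (hypothetical numbers): a global batch of 256 split
# across 4 GPUs gives 64 examples per replica, while 250 across 4 GPUs raises
# a ValueError suggesting --batch_size=248.
#
#   per_replica = per_replica_batch_size(batch_size=256, num_gpus=4)  # -> 64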


# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have a GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers
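

# Illustrative sketch (the dataset below is hypothetical): `SyntheticDataset`
# caches a single batch from a real dataset in local variables and then serves
# that same batch forever, which is handy for benchmarking input-bound models.
#
#   ds = tf.data.Dataset.from_tensors(tf.zeros([32, 4])).repeat()
#   synthetic = SyntheticDataset(ds, split_by=2)  # per-replica batch of 16
#   batch = synthetic.make_one_shot_iterator().get_next()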


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
  def make_dataset(self, dataset):
    tf.compat.v1.logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset


def set_up_synthetic_data():
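  """Monkey-patches the core strategy classes to return synthetic data."""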
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
  if hasattr(tf, 'contrib'):
    _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
    _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
    _monkey_patch_dataset_method(
        tf.contrib.distribute.CollectiveAllReduceStrategy)
  else:
    print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')


def undo_set_up_synthetic_data():
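  """Removes the synthetic-data monkey patches from the strategy classes."""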
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
  if hasattr(tf, 'contrib'):
    _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
    _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
    _undo_monkey_patch_dataset_method(
        tf.contrib.distribute.CollectiveAllReduceStrategy)
  else:
    print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
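# Hedged usage sketch (`real_dataset` and the training loop are placeholders,
# not part of this module): patch the strategy classes so distributed datasets
# serve a cached in-memory batch, run the benchmark, then restore the original
# methods.
#
#   set_up_synthetic_data()
#   try:
#     strategy = get_distribution_strategy("mirrored", num_gpus=2)
#     dist_dataset = strategy.experimental_distribute_dataset(real_dataset)
#     # ... run the training loop on the repeated synthetic batch ...
#   finally:
#     undo_set_up_synthetic_data()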


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the current task (worker) within the cluster.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  if tf_config:
    num_workers = (len(tf_config['cluster'].get('chief', [])) +
                   len(tf_config['cluster'].get('worker', [])))
  elif worker_hosts:
    workers = worker_hosts.split(',')
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError('Must specify task_index when number of workers > 1')
    task_index = 0 if num_workers == 1 else task_index
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': workers
        },
        'task': {'type': 'worker', 'index': task_index}
    })
  else:
    num_workers = 1
  return num_workers
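

# Hypothetical single-machine example: with two fake worker addresses, the
# call below writes a TF_CONFIG environment variable like
#   {"cluster": {"worker": ["10.0.0.1:5000", "10.0.0.2:5000"]},
#    "task": {"type": "worker", "index": 0}}
# and returns 2.
#
#   num_workers = configure_cluster(
#       worker_hosts="10.0.0.1:5000,10.0.0.2:5000", task_index=0)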


def get_strategy_scope(strategy):
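  """Returns the strategy scope, or a dummy context manager if None."""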
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
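  """A no-op context manager used when no distribution strategy is set."""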

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
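

# Minimal sketch of using `get_strategy_scope` so model-building code works
# with or without a distribution strategy (the model below is hypothetical):
#
#   strategy = get_distribution_strategy("mirrored", num_gpus=2)  # or None
#   with get_strategy_scope(strategy):
#     model = tf.keras.Sequential([tf.keras.layers.Dense(10)])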