# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random
import string

from absl import logging
import tensorflow.compat.v2 as tf

from official.utils.misc import tpu_lib


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`].  Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]
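
# Illustration (not part of the original module) of what the helper above
# returns; the flag values shown are the ones its docstring accepts:
#
#   _collective_communication("nccl")
#   # -> tf.distribute.experimental.CollectiveCommunication.NCCL
#   _collective_communication(None)
#   # -> tf.distribute.experimental.CollectiveCommunication.AUTO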


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`].  Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)
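
# Illustration (not part of the original module); num_packs=2 is an arbitrary
# value chosen for the sketch:
#
#   _mirrored_cross_device_ops("hierarchical_copy", num_packs=2)
#   # -> a tf.distribute.HierarchicalCopyAllReduce instance
#   _mirrored_cross_device_ops(None, num_packs=2)
#   # -> None, letting MirroredStrategy pick its default cross-device ops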


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "off" means not to use Distribution Strategy; "tpu" means
      to use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl".  If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional.  Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
  Returns:
    tf.distribute.DistributionStrategy object.
  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; if `num_gpus` is negative; or if
      `distribution_strategy` is "tpu" but `tpu_address` is not specified.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(
          "When {} GPUs are specified, distribution_strategy "
          "flag cannot be set to `off`.".format(num_gpus))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      "Unrecognized Distribution Strategy: %r" % distribution_strategy)

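# Example usage (a sketch, not part of the original module), assuming a single
# host with two GPUs; the tiny Keras model is only illustrative:
#
#   strategy = get_distribution_strategy(
#       distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")
#   with strategy.scope():
#     model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
#     model.compile(optimizer="sgd", loss="mse")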

def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the current worker in the cluster; required when
      there is more than one worker.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (len(tf_config["cluster"].get("chief", [])) +
                   len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {"type": "worker", "index": task_index}
    })
  else:
    num_workers = 1
  return num_workers
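
# Example usage (a sketch, not part of the original module); the host:port
# pairs are placeholders:
#
#   num_workers = configure_cluster(
#       worker_hosts="10.0.0.1:2222,10.0.0.2:2222", task_index=0)
#   # TF_CONFIG now describes a two-worker cluster with this process acting
#   # as worker 0, which MultiWorkerMirroredStrategy reads when constructed:
#   strategy = get_distribution_strategy("multi_worker_mirrored")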


def get_strategy_scope(strategy):
  """Returns `strategy.scope()`, or a no-op context if `strategy` is None."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is set."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
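

# Example usage (a sketch, not part of the original module): model-building
# code can be written once and run with or without a distribution strategy.
#
#   strategy = get_distribution_strategy("one_device", num_gpus=1)
#   # or strategy = get_distribution_strategy("off"), which returns None
#   with get_strategy_scope(strategy):
#     model = tf.keras.Sequential([tf.keras.layers.Dense(10)])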