Commit 420fd1cf authored by Hongkun Yu, committed by A. Unique TensorFlower

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
parent c7647f11
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""
import json
import os
import random
import string
from absl import logging
import tensorflow as tf
def _collective_communication(all_reduce_alg):
"""Return a CollectiveCommunication based on all_reduce_alg.
Args:
all_reduce_alg: a string specifying which collective communication to pick,
or None.
Returns:
tf.distribute.experimental.CollectiveCommunication object
Raises:
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
"""
collective_communication_options = {
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
"ring": tf.distribute.experimental.CollectiveCommunication.RING,
"nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
}
if all_reduce_alg not in collective_communication_options:
raise ValueError(
"When used with `multi_worker_mirrored`, valid values for "
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
all_reduce_alg))
return collective_communication_options[all_reduce_alg]
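# For example (illustrative): _collective_communication("ring") returns
# tf.distribute.experimental.CollectiveCommunication.RING, and passing None
# selects CollectiveCommunication.AUTO.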
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
Args:
all_reduce_alg: a string specifying which cross device op to pick, or None.
num_packs: an integer specifying number of packs for the cross device op.
Returns:
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if all_reduce_alg is None:
return None
mirrored_all_reduce_options = {
"nccl": tf.distribute.NcclAllReduce,
"hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
}
if all_reduce_alg not in mirrored_all_reduce_options:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
all_reduce_alg))
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
return cross_device_ops_class(num_packs=num_packs)
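# For example (illustrative): _mirrored_cross_device_ops("nccl", 2) returns
# tf.distribute.NcclAllReduce(num_packs=2), while all_reduce_alg=None returns
# None so that MirroredStrategy falls back to its default cross-device ops.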
def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.x training.
Args:
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ("", "local"):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver
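# Illustrative sketch (not part of the original module): the resolver returned
# above can be used to build a TPUStrategy, mirroring what
# get_distribution_strategy does below. The address is a made-up example value.
#
#   resolver = tpu_initialize("grpc://10.0.0.2:8470")
#   strategy = tf.distribute.experimental.TPUStrategy(resolver)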
def get_distribution_strategy(distribution_strategy="mirrored",
num_gpus=0,
all_reduce_alg=None,
num_packs=1,
tpu_address=None):
"""Return a DistributionStrategy for running the model.
Args:
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
insensitive. "off" means not to use Distribution Strategy; "tpu" means to
use TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
"ring" and "nccl". If None, DistributionStrategy will choose based on
device topology.
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
tpu_address: Optional. String that represents TPU to connect to. Must not be
None if `distribution_strategy` is set to `tpu`.
Returns:
tf.distribute.DistributionStrategy object.
Raises:
ValueError: if `distribution_strategy` is "off" or "one_device" and
`num_gpus` is larger than 1; if `num_gpus` is negative; or if
`distribution_strategy` is "tpu" but `tpu_address` is not specified.
"""
if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.")
distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError("When {} GPUs are specified, distribution_strategy "
"flag cannot be set to `off`.".format(num_gpus))
return None
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
cluster_resolver = tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
if distribution_strategy == "multi_worker_mirrored":
return tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=_collective_communication(all_reduce_alg))
if distribution_strategy == "one_device":
if num_gpus == 0:
return tf.distribute.OneDeviceStrategy("device:CPU:0")
if num_gpus > 1:
raise ValueError("`OneDeviceStrategy` can not be used for more than "
"one device.")
return tf.distribute.OneDeviceStrategy("device:GPU:0")
if distribution_strategy == "mirrored":
if num_gpus == 0:
devices = ["device:CPU:0"]
else:
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
return tf.distribute.MirroredStrategy(
devices=devices,
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy()
raise ValueError("Unrecognized Distribution Strategy: %r" %
distribution_strategy)
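# Illustrative usage sketch (not part of the original module). The values below
# (two GPUs, "nccl", two packs) are hypothetical; any model can be built inside
# the returned strategy's scope.
def _example_mirrored_usage():
  """Example only: builds a MirroredStrategy and creates a model in its scope."""
  strategy = get_distribution_strategy(
      distribution_strategy="mirrored",
      num_gpus=2,
      all_reduce_alg="nccl",
      num_packs=2)
  with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  return model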
def configure_cluster(worker_hosts=None, task_index=-1):
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
Args:
worker_hosts: comma-separated list of worker ip:port pairs.
task_index: index of the current worker; must be non-negative when more than
one worker is specified.
Returns:
Number of workers in the cluster.
"""
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if tf_config:
num_workers = (
len(tf_config["cluster"].get("chief", [])) +
len(tf_config["cluster"].get("worker", [])))
elif worker_hosts:
workers = worker_hosts.split(",")
num_workers = len(workers)
if num_workers > 1 and task_index < 0:
raise ValueError("Must specify task_index when number of workers > 1")
task_index = 0 if num_workers == 1 else task_index
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": workers
},
"task": {
"type": "worker",
"index": task_index
}
})
else:
num_workers = 1
return num_workers
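# Illustrative usage sketch (not part of the original module). The worker
# addresses are made-up example values; assuming TF_CONFIG is not already set
# in the environment, the call writes a cluster spec of the form
# {"cluster": {"worker": [...]}, "task": {"type": "worker", "index": 0}}.
def _example_configure_cluster():
  """Example only: sets TF_CONFIG for a two-worker cluster from worker 0."""
  num_workers = configure_cluster(
      worker_hosts="10.0.0.1:2222,10.0.0.2:2222", task_index=0)
  logging.info("Configured TF_CONFIG for %d workers: %s", num_workers,
               os.environ["TF_CONFIG"])
  return num_workers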
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
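# Illustrative usage sketch (not part of the original module): get_strategy_scope
# returns the real strategy scope when a strategy exists and a no-op
# DummyContextManager otherwise, so calling code never has to branch.
def _example_strategy_scope():
  """Example only: builds and compiles a model under whatever scope is available."""
  strategy = get_distribution_strategy("one_device", num_gpus=0)
  with get_strategy_scope(strategy):
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
  return model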
@@ -14,32 +14,28 @@
 # ==============================================================================
 """ Tests for distribution util functions."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-import tensorflow.compat.v2 as tf
-from official.utils.misc import distribution_utils
+import tensorflow as tf
+from official.common import distribute_utils
 class GetDistributionStrategyTest(tf.test.TestCase):
   """Tests for get_distribution_strategy."""
   def test_one_device_strategy_cpu(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=0)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=0)
     self.assertEquals(ds.num_replicas_in_sync, 1)
     self.assertEquals(len(ds.extended.worker_devices), 1)
     self.assertIn('CPU', ds.extended.worker_devices[0])
   def test_one_device_strategy_gpu(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=1)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=1)
     self.assertEquals(ds.num_replicas_in_sync, 1)
     self.assertEquals(len(ds.extended.worker_devices), 1)
     self.assertIn('GPU', ds.extended.worker_devices[0])
   def test_mirrored_strategy(self):
-    ds = distribution_utils.get_distribution_strategy(num_gpus=5)
+    ds = distribute_utils.get_distribution_strategy(num_gpus=5)
     self.assertEquals(ds.num_replicas_in_sync, 5)
     self.assertEquals(len(ds.extended.worker_devices), 5)
     for device in ds.extended.worker_devices:
...
@@ -31,7 +31,7 @@ import tensorflow as tf
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
-from official.utils.misc import distribution_utils
+from official.common import distribute_utils
 from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
@@ -745,7 +745,7 @@ class ExecutorBuilder(object):
  """
  def __init__(self, strategy_type=None, strategy_config=None):
-    _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
+    _ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                              strategy_config.task_index)
    """Constructor.
@@ -756,7 +756,7 @@ class ExecutorBuilder(object):
      strategy_config: necessary config for constructing the proper Strategy.
        Check strategy_flags_dict() for examples of the structure.
    """
-    self._strategy = distribution_utils.get_distribution_strategy(
+    self._strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
...
@@ -26,11 +26,10 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
+from official.common import distribute_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
-from official.utils.misc import distribution_utils
 FLAGS = flags.FLAGS
@@ -77,7 +76,7 @@ def main(_):
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
...
@@ -27,12 +27,11 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
+from official.common import distribute_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib_sp
-from official.utils.misc import distribution_utils
 flags.DEFINE_string(
     'sp_model_file', None,
@@ -104,9 +103,8 @@ def main(_):
   # Configures cluster spec for multi-worker distribution strategy.
   if FLAGS.num_gpus > 0:
-    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
-                                             FLAGS.task_index)
-  strategy = distribution_utils.get_distribution_strategy(
+    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
...
@@ -25,8 +25,8 @@ import tempfile
 from absl import logging
 import tensorflow as tf
 from tensorflow.python.util import deprecation
+from official.common import distribute_utils
 from official.staging.training import grad_utils
-from official.utils.misc import distribution_utils
 _SUMMARY_TXT = 'training_summary.txt'
 _MIN_SUMMARY_STEPS = 10
@@ -266,7 +266,7 @@ def run_customized_training_loop(
   train_iterator = _get_input_iterator(train_input_fn, strategy)
   eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
-  with distribution_utils.get_strategy_scope(strategy):
+  with distribute_utils.get_strategy_scope(strategy):
     # To correctly place the model weights on accelerators,
     # model and optimizer should be created in scope.
     model, sub_model = model_fn()
...
@@ -28,6 +28,7 @@ from absl import flags
 from absl import logging
 import gin
 import tensorflow as tf
+from official.common import distribute_utils
 from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
@@ -35,7 +36,6 @@ from official.nlp.bert import common_flags
 from official.nlp.bert import configs as bert_configs
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 flags.DEFINE_enum(
@@ -447,7 +447,7 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
                  FLAGS.model_dir)
     return
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
...
@@ -23,6 +23,7 @@ from absl import flags
 from absl import logging
 import gin
 import tensorflow as tf
+from official.common import distribute_utils
 from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
@@ -30,7 +31,6 @@ from official.nlp.bert import common_flags
 from official.nlp.bert import configs
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_training_utils
-from official.utils.misc import distribution_utils
 flags.DEFINE_string('input_files', None,
@@ -205,9 +205,8 @@ def main(_):
     FLAGS.model_dir = '/tmp/bert20/'
   # Configures cluster spec for multi-worker distribution strategy.
   if FLAGS.num_gpus > 0:
-    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
-                                             FLAGS.task_index)
-  strategy = distribution_utils.get_distribution_strategy(
+    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
...
@@ -28,12 +28,11 @@ from absl import flags
 from absl import logging
 import gin
 import tensorflow as tf
+from official.common import distribute_utils
 from official.nlp.bert import configs as bert_configs
 from official.nlp.bert import run_squad_helper
 from official.nlp.bert import tokenization
 from official.nlp.data import squad_lib as squad_lib_wp
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
@@ -105,9 +104,8 @@ def main(_):
   # Configures cluster spec for multi-worker distribution strategy.
   if FLAGS.num_gpus > 0:
-    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
-                                             FLAGS.task_index)
-  strategy = distribution_utils.get_distribution_strategy(
+    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
...
@@ -27,13 +27,13 @@ from absl import flags
 from absl import logging
 from six.moves import zip
 import tensorflow as tf
+from official.common import distribute_utils
 from official.modeling.hyperparams import params_dict
 from official.nlp.nhnet import evaluation
 from official.nlp.nhnet import input_pipeline
 from official.nlp.nhnet import models
 from official.nlp.nhnet import optimizer
 from official.nlp.transformer import metrics as transformer_metrics
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
@@ -185,7 +185,7 @@ def run():
   if FLAGS.enable_mlir_bridge:
     tf.config.experimental.enable_mlir_bridge()
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
   if strategy:
     logging.info("***** Number of cores used : %d",
...
@@ -23,11 +23,11 @@ from official.core import train_utils
 # pylint: disable=unused-import
 from official.common import registry_imports
 # pylint: enable=unused-import
+from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
 from official.modeling import performance
-from official.utils.misc import distribution_utils
 FLAGS = flags.FLAGS
@@ -48,7 +48,7 @@ def main(_):
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                            params.runtime.loss_scale)
-  distribution_strategy = distribution_utils.get_distribution_strategy(
+  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
...
@@ -28,13 +28,13 @@ import tensorflow as tf
 # pylint: disable=unused-import
 from official.common import registry_imports
 # pylint: enable=unused-import
+from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
 from official.modeling.hyperparams import config_definitions
-from official.utils.misc import distribution_utils
 FLAGS = flags.FLAGS
@@ -77,7 +77,7 @@ def run_continuous_finetune(
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                            params.runtime.loss_scale)
-  distribution_strategy = distribution_utils.get_distribution_strategy(
+  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
...
@@ -29,7 +29,7 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
+from official.common import distribute_utils
 from official.modeling import performance
 from official.nlp.transformer import compute_bleu
 from official.nlp.transformer import data_pipeline
@@ -40,7 +40,6 @@ from official.nlp.transformer import transformer
 from official.nlp.transformer import translate
 from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 INF = int(1e9)
@@ -161,7 +160,7 @@ class TransformerTask(object):
    params["steps_between_evals"] = flags_obj.steps_between_evals
    params["enable_checkpointing"] = flags_obj.enable_checkpointing
-    self.distribution_strategy = distribution_utils.get_distribution_strategy(
+    self.distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
@@ -197,7 +196,7 @@ class TransformerTask(object):
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
    _ensure_dir(flags_obj.model_dir)
-    with distribution_utils.get_strategy_scope(self.distribution_strategy):
+    with distribute_utils.get_strategy_scope(self.distribution_strategy):
      model = transformer.create_model(params, is_train=True)
      opt = self._create_optimizer()
@@ -376,7 +375,7 @@ class TransformerTask(object):
    # We only want to create the model under DS scope for TPU case.
    # When 'distribution_strategy' is None, a no-op DummyContextManager will
    # be used.
-    with distribution_utils.get_strategy_scope(distribution_strategy):
+    with distribute_utils.get_strategy_scope(distribution_strategy):
      if not self.predict_model:
        self.predict_model = transformer.create_model(self.params, False)
      self._load_weights_if_possible(
...
@@ -23,13 +23,13 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 # pylint: disable=unused-import
+from official.common import distribute_utils
 from official.nlp.xlnet import common_flags
 from official.nlp.xlnet import data_utils
 from official.nlp.xlnet import optimization
 from official.nlp.xlnet import training_utils
 from official.nlp.xlnet import xlnet_config
 from official.nlp.xlnet import xlnet_modeling as modeling
-from official.utils.misc import distribution_utils
 flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
 flags.DEFINE_string(
@@ -130,7 +130,7 @@ def get_metric_fn():
 def main(unused_argv):
   del unused_argv
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      tpu_address=FLAGS.tpu)
   if strategy:
...
@@ -23,13 +23,13 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
 # pylint: disable=unused-import
+from official.common import distribute_utils
 from official.nlp.xlnet import common_flags
 from official.nlp.xlnet import data_utils
 from official.nlp.xlnet import optimization
 from official.nlp.xlnet import training_utils
 from official.nlp.xlnet import xlnet_config
 from official.nlp.xlnet import xlnet_modeling as modeling
-from official.utils.misc import distribution_utils
 flags.DEFINE_integer(
     "num_predict",
@@ -72,7 +72,7 @@ def get_pretrainxlnet_model(model_config, run_config):
 def main(unused_argv):
   del unused_argv
   num_hosts = 1
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      tpu_address=FLAGS.tpu)
   if FLAGS.strategy_type == "tpu":
...
@@ -27,6 +27,7 @@ from absl import logging
 import tensorflow as tf
 # pylint: disable=unused-import
 import sentencepiece as spm
+from official.common import distribute_utils
 from official.nlp.xlnet import common_flags
 from official.nlp.xlnet import data_utils
 from official.nlp.xlnet import optimization
@@ -34,7 +35,6 @@ from official.nlp.xlnet import squad_utils
 from official.nlp.xlnet import training_utils
 from official.nlp.xlnet import xlnet_config
 from official.nlp.xlnet import xlnet_modeling as modeling
-from official.utils.misc import distribution_utils
 flags.DEFINE_string(
     "test_feature_path", default=None, help="Path to feature of test set.")
@@ -212,7 +212,7 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
 def main(unused_argv):
   del unused_argv
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      tpu_address=FLAGS.tpu)
   if strategy:
...
@@ -21,20 +21,17 @@ from __future__ import print_function
 import json
 import os
-# pylint: disable=g-bad-import-order
-import numpy as np
 from absl import flags
 from absl import logging
+import numpy as np
 import tensorflow as tf
-# pylint: enable=g-bad-import-order
+from official.common import distribute_utils
 from official.recommendation import constants as rconst
 from official.recommendation import data_pipeline
 from official.recommendation import data_preprocessing
 from official.recommendation import movielens
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
@@ -142,7 +139,7 @@ def get_v1_distribution_strategy(params):
         tpu_cluster_resolver, steps_per_run=100)
   else:
-    distribution = distribution_utils.get_distribution_strategy(
+    distribution = distribute_utils.get_distribution_strategy(
        num_gpus=params["num_gpus"])
   return distribution
...
@@ -33,13 +33,13 @@ from absl import logging
 import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order
+from official.common import distribute_utils
 from official.recommendation import constants as rconst
 from official.recommendation import movielens
 from official.recommendation import ncf_common
 from official.recommendation import ncf_input_pipeline
 from official.recommendation import neumf_model
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
@@ -225,7 +225,7 @@ def run_ncf(_):
       loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
   tf.keras.mixed_precision.experimental.set_policy(policy)
-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
@@ -271,7 +271,7 @@ def run_ncf(_):
          params, producer, input_meta_data, strategy))
   steps_per_epoch = None if generate_input_online else num_train_steps
-  with distribution_utils.get_strategy_scope(strategy):
+  with distribute_utils.get_strategy_scope(strategy):
     keras_model = _get_keras_model(params)
     optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
...
@@ -13,197 +13,5 @@
 # limitations under the License.
 # ==============================================================================
 """Helper functions for running models in a distributed setting."""
+# pylint: disable=wildcard-import
+from official.common.distribute_utils import *
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-import json
-import os
-import random
-import string
-from absl import logging
-import tensorflow.compat.v2 as tf
-from official.utils.misc import tpu_lib
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Initializes TPU system for TF 2.0."""
import tensorflow as tf
def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.0 training.
Args:
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ('', 'local'):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver