"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "b690f2459a5c70f2a6e3b3450fe408edafc8cfb4"
Commit 4270e416 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 283770513
parent ad56514f
@@ -32,32 +32,13 @@ import tensorflow as tf

 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any

 from official.modeling.hyperparams import params_dict
 from official.utils.misc import tpu_lib
+from official.utils.misc import distribution_utils
+from official.utils import hyperparams_flags

 FLAGS = flags.FLAGS

-
-def strategy_flags_dict():
-  """Returns TPU related flags in a dictionary."""
-  return {
-      # TPUStrategy related flags.
-      'tpu': FLAGS.tpu,
-      # MultiWorkerMirroredStrategy related flags.
-      'worker_hosts': FLAGS.worker_hosts,
-      'task_index': FLAGS.task_index,
-  }
-
-
-def hparam_flags_dict():
-  """Returns model params related flags in a dictionary."""
-  return {
-      'data_dir': FLAGS.data_dir,
-      'model_dir': FLAGS.model_dir,
-      'train_batch_size': FLAGS.train_batch_size,
-      'eval_batch_size': FLAGS.eval_batch_size,
-      'precision': FLAGS.precision,
-      'config_file': FLAGS.config_file,
-      'params_override': FLAGS.params_override,
-  }
+strategy_flags_dict = hyperparams_flags.strategy_flags_dict
+hparam_flags_dict = hyperparams_flags.hparam_flags_dict


 def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
@@ -647,7 +628,6 @@ class DistributedExecutor(object):
     return NotImplementedError('Unimplmented function.')


-# TODO(yeqing): Add unit test for MultiWorkerMirroredStrategy.
 class ExecutorBuilder(object):
   """Builder of DistributedExecutor.

@@ -692,8 +672,15 @@ class ExecutorBuilder(object):
   """

   def __init__(self, strategy_type=None, strategy_config=None):
-    self._strategy_config = strategy_config
-    self._strategy = self._build_strategy(strategy_type)
+    num_workers = distribution_utils.configure_cluster(
+        strategy_config.worker_hosts, strategy_config.task_index)
+    self._strategy = distribution_utils.get_distribution_strategy(
+        distribution_strategy=strategy_type,
+        num_gpus=strategy_config.num_gpus,
+        num_workers=num_workers,
+        all_reduce_alg=strategy_config.all_reduce_alg,
+        num_packs=strategy_config.num_packs,
+        tpu_address=strategy_config.tpu)

   @property
   def strategy(self):
@@ -705,66 +692,6 @@ class ExecutorBuilder(object):
     """Sets default summary writer for the current thread."""
     self._strategy = new_strategy

-  def _build_strategy(self, strategy_type):
-    """Builds tf.distribute.Strategy instance.
-
-    Args:
-      strategy_type: string. One of 'tpu', 'one_device_gpu', 'mirrored',
-        'multi_worker_mirrored'.
-
-    Returns:
-      An tf.distribute.Strategy object. Returns None if strategy_type is None.
-    """
-    if strategy_type is None:
-      return None
-    if strategy_type == 'tpu':
-      return self._build_tpu_strategy()
-    elif strategy_type == 'one_device_gpu':
-      return tf.distribute.OneDeviceStrategy("device:GPU:0")
-    elif strategy_type == 'mirrored':
-      return self._build_mirrored_strategy()
-    elif strategy_type == 'multi_worker_mirrored':
-      return self._build_multiworker_mirrored_strategy()
-    else:
-      raise NotImplementedError('Unsupport accelerator type "%s"' %
-                                strategy_type)
-
-  def _build_mirrored_strategy(self):
-    """Builds a MirroredStrategy object."""
-    return tf.distribute.MirroredStrategy()
-
-  def _build_tpu_strategy(self):
-    """Builds a TPUStrategy object."""
-    tpu = self._strategy_config.tpu
-    logging.info('Use TPU at %s', tpu if tpu is not None else '')
-    cluster_resolver = tpu_lib.tpu_initialize(tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    return strategy
-
-  def _build_multiworker_mirrored_strategy(self):
-    """Builds a MultiWorkerMirroredStrategy object."""
-    worker_hosts = self._strategy_config.worker_hosts
-    if worker_hosts is not None:
-      # Set TF_CONFIG environment variable
-      worker_hosts = worker_hosts.split(',')
-      task_index = self._strategy_config.task_index
-      os.environ['TF_CONFIG'] = json.dumps({
-          'cluster': {
-              'worker': worker_hosts
-          },
-          'task': {
-              'type': 'worker',
-              'index': task_index
-          }
-      })
-    multiworker_strategy = (
-        tf.distribute.experimental.MultiWorkerMirroredStrategy())
-    return multiworker_strategy
-
   def build_executor(self,
                      class_ctor=DistributedExecutor,
...
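Usage sketch (not part of the commit): after this change, ExecutorBuilder no longer builds strategies itself; it only reads six fields off strategy_config (tpu, worker_hosts, task_index, num_gpus, all_reduce_alg, num_packs) and hands them to distribution_utils. The driver code and the import path of the executor module below are assumptions for illustration; the ExecutorBuilder signature and the field names come from the hunk above.

# Hypothetical driver code, not from this commit. The import path of
# distributed_executor is assumed; ExecutorBuilder's signature and the
# strategy_config fields are taken from the diff above.
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor

# get_distribution_strategy() is driven entirely by these six fields, so any
# object exposing them as attributes works as strategy_config.
strategy_config = params_dict.ParamsDict({
    'tpu': None,            # TPUStrategy
    'worker_hosts': None,   # MultiWorkerMirroredStrategy
    'task_index': 0,
    'num_gpus': 0,          # MirroredStrategy / OneDeviceStrategy
    'all_reduce_alg': None,
    'num_packs': 1,
})

builder = executor.ExecutorBuilder(
    strategy_type='mirrored', strategy_config=strategy_config)
# With num_gpus=0 this resolves to a MirroredStrategy over the CPU.
print(builder.strategy)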
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 from absl import flags
+from official.utils.flags import core as flags_core

 FLAGS = flags.FLAGS
@@ -68,33 +69,51 @@ def define_common_hparams_flags():
           'The final override order of parameters: default_model_params --> '
           'params from config_file --> params in params_override.'
           'See also the help message of `--config_file`.'))
-  flags.DEFINE_string(
-      'strategy_type', 'mirrored', 'Type of distribute strategy.'
-      'One of mirrored, tpu and multiworker.')
+  flags.DEFINE_integer('save_checkpoint_freq', None,
+                       'Number of steps to save checkpoint.')


 def initialize_common_flags():
   """Define the common flags across models."""
-  key_flags = []
   define_common_hparams_flags()

-  flags.DEFINE_string(
-      'tpu',
-      default=None,
-      help='The Cloud TPU to use for training. This should be either the name '
-      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
-      'url.')
-  # Parameters for MultiWorkerMirroredStrategy
-  flags.DEFINE_string(
-      'worker_hosts',
-      default=None,
-      help='Comma-separated list of worker ip:port pairs for running '
-      'multi-worker models with distribution strategy. The user would '
-      'start the program on each host with identical value for this flag.')
-  flags.DEFINE_integer(
-      'task_index', 0,
-      'If multi-worker training, the task_index of this worker.')
-  flags.DEFINE_integer('save_checkpoint_freq', None,
-                       'Number of steps to save checkpoint.')
-  return key_flags
+  flags_core.define_device(tpu=True)
+  flags_core.define_base(
+      num_gpu=True, model_dir=False, data_dir=False, batch_size=False)
+  flags_core.define_distribution(worker_hosts=True, task_index=True)
+  flags_core.define_performance(all_reduce_alg=True, num_packs=True)
+
+  # Reset the default value of num_gpus to zero.
+  FLAGS.num_gpus = 0
+
+  flags.DEFINE_string(
+      'strategy_type', 'mirrored', 'Type of distribute strategy.'
+      'One of mirrored, tpu and multiworker.')
+
+
+def strategy_flags_dict():
+  """Returns TPU and/or GPU related flags in a dictionary."""
+  return {
+      # TPUStrategy related flags.
+      'tpu': FLAGS.tpu,
+      # MultiWorkerMirroredStrategy related flags.
+      'all_reduce_alg': FLAGS.all_reduce_alg,
+      'worker_hosts': FLAGS.worker_hosts,
+      'task_index': FLAGS.task_index,
+      # MirroredStrategy and OneDeviceStrategy
+      'num_gpus': FLAGS.num_gpus,
+      'num_packs': FLAGS.num_packs,
+  }
+
+
+def hparam_flags_dict():
+  """Returns model params related flags in a dictionary."""
+  return {
+      'data_dir': FLAGS.data_dir,
+      'model_dir': FLAGS.model_dir,
+      'train_batch_size': FLAGS.train_batch_size,
+      'eval_batch_size': FLAGS.eval_batch_size,
+      'precision': FLAGS.precision,
+      'config_file': FLAGS.config_file,
+      'params_override': FLAGS.params_override,
+  }
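Usage sketch (not part of the commit): the flags_core.define_* calls above register --tpu, --num_gpus, --worker_hosts, --task_index, --all_reduce_alg and --num_packs, and strategy_flags_dict() now exports exactly the values that distribution_utils.get_distribution_strategy() consumes, mirroring the rewritten ExecutorBuilder.__init__ in the first file. The small driver below is hypothetical; only the module imports, function names and flag names come from the diff.

# Hypothetical driver, not from this commit: flag values -> distribution
# strategy, mirroring the rewritten ExecutorBuilder.__init__.
from absl import app, flags

from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils

FLAGS = flags.FLAGS
hyperparams_flags.initialize_common_flags()


def main(_):
  d = hyperparams_flags.strategy_flags_dict()
  # configure_cluster() sets up TF_CONFIG for multi-worker runs and returns
  # the worker count, replacing the hand-rolled TF_CONFIG block deleted from
  # _build_multiworker_mirrored_strategy.
  num_workers = distribution_utils.configure_cluster(
      d['worker_hosts'], d['task_index'])
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      num_gpus=d['num_gpus'],
      num_workers=num_workers,
      all_reduce_alg=d['all_reduce_alg'],
      num_packs=d['num_packs'],
      tpu_address=d['tpu'])
  print(strategy)


if __name__ == '__main__':
  app.run(main)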