Commit 420fd1cf authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
parent c7647f11
......@@ -19,13 +19,13 @@ from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.common import registry_imports # pylint: disable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
......@@ -46,7 +46,7 @@ def main(_):
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......
......@@ -14,28 +14,19 @@
# ==============================================================================
"""Main function to train various object detection models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import pprint
# pylint: disable=g-bad-import-order
# Import libraries
import tensorflow as tf
from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order
import tensorflow as tf
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
......@@ -87,9 +78,9 @@ def run_executor(params,
strategy = prebuilt_strategy
else:
strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
......
......@@ -23,11 +23,10 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import dataset_factory
......@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
distribution_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribution_utils.get_distribution_strategy(
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
......
......@@ -25,9 +25,8 @@ from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.common import distribute_utils
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common
......@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns:
Dictionary of training and eval stats.
"""
strategy = strategy_override or distribution_utils.get_distribution_strategy(
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download:
......
......@@ -23,10 +23,9 @@ from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common
......@@ -117,7 +116,7 @@ def run(flags_obj):
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
......@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj.batch_size,
flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy):
with distribute_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment