Commit 33a4c207 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
parent 41a1e1d6
...@@ -19,13 +19,13 @@ from absl import app ...@@ -19,13 +19,13 @@ from absl import app
from absl import flags from absl import flags
import gin import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.common import registry_imports # pylint: disable=unused-import from official.common import registry_imports # pylint: disable=unused-import
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils from official.core import train_utils
from official.modeling import performance from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -46,7 +46,7 @@ def main(_): ...@@ -46,7 +46,7 @@ def main(_):
if params.runtime.mixed_precision_dtype: if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype, performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale) params.runtime.loss_scale)
distribution_strategy = distribution_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus, num_gpus=params.runtime.num_gpus,
......
...@@ -14,28 +14,19 @@ ...@@ -14,28 +14,19 @@
# ============================================================================== # ==============================================================================
"""Main function to train various object detection models.""" """Main function to train various object detection models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
import pprint import pprint
# pylint: disable=g-bad-import-order
# Import libraries
import tensorflow as tf
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
# pylint: enable=g-bad-import-order import tensorflow as tf
from official.common import distribute_utils
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor from official.modeling.training import distributed_executor as executor
from official.utils import hyperparams_flags from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader from official.vision.detection.dataloader import input_reader
...@@ -87,9 +78,9 @@ def run_executor(params, ...@@ -87,9 +78,9 @@ def run_executor(params,
strategy = prebuilt_strategy strategy = prebuilt_strategy
else: else:
strategy_config = params.strategy_config strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts, distribute_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index) strategy_config.task_index)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.strategy_type, distribution_strategy=params.strategy_type,
num_gpus=strategy_config.num_gpus, num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg, all_reduce_alg=strategy_config.all_reduce_alg,
......
...@@ -23,11 +23,10 @@ from absl import app ...@@ -23,11 +23,10 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import performance from official.modeling import performance
from official.utils import hyperparams_flags from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import dataset_factory from official.vision.image_classification import dataset_factory
...@@ -291,17 +290,17 @@ def train_and_eval( ...@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit.""" """Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.') logging.info('Running train and eval.')
distribution_utils.configure_cluster(params.runtime.worker_hosts, distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index) params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset # Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribution_utils.get_distribution_strategy( strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus, num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu) tpu_address=params.runtime.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribute_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.', logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1) strategy.num_replicas_in_sync if strategy else 1)
......
...@@ -25,9 +25,8 @@ from absl import flags ...@@ -25,9 +25,8 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from official.common import distribute_utils
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common from official.vision.image_classification.resnet import common
...@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None): ...@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns: Returns:
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
strategy = strategy_override or distribution_utils.get_distribution_strategy( strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus, num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu) tpu_address=flags_obj.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribute_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir) mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download: if flags_obj.download:
......
...@@ -23,10 +23,9 @@ from absl import flags ...@@ -23,10 +23,9 @@ from absl import flags
from absl import logging from absl import logging
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance from official.modeling import performance
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common from official.vision.image_classification.resnet import common
...@@ -117,7 +116,7 @@ def run(flags_obj): ...@@ -117,7 +116,7 @@ def run(flags_obj):
else 'channels_last') else 'channels_last')
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus, num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
...@@ -144,7 +143,7 @@ def run(flags_obj): ...@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj.batch_size, flags_obj.batch_size,
flags_obj.log_steps, flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None) logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy): with distribute_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback, runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps) per_epoch_steps)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment