Commit faf7a4f5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
parent 420fd1cf
......@@ -14,29 +14,22 @@
# ==============================================================================
"""Executes BERT benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import math
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import configs
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
......@@ -76,10 +69,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if self.tpu:
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
else:
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
......
......@@ -15,8 +15,6 @@
# ==============================================================================
"""Executes benchmark testing for bert pretraining."""
# pylint: disable=line-too-long
from __future__ import print_function
import json
import os
import time
......@@ -24,14 +22,14 @@ from typing import Optional
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow as tf
from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils
from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import run_pretraining
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
# Pretrain masked language modeling accuracy range:
MIN_MLM_ACCURACY = 0.635
......@@ -85,13 +83,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
      A `tf.distribute.DistributionStrategy` object.
"""
if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
return distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribution_utils.get_distribution_strategy(
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribute_utils.get_distribution_strategy(
distribution_strategy=ds_type,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg)
......
......@@ -13,29 +13,21 @@
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long
......@@ -83,13 +75,13 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
      A `tf.distribute.DistributionStrategy` object.
"""
if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
return distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribution_utils.get_distribution_strategy(
_ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribute_utils.get_distribution_strategy(
distribution_strategy=ds_type,
num_gpus=self.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg)
......
......@@ -27,8 +27,8 @@ import tensorflow as tf
from official.benchmark.models import cifar_preprocessing
from official.benchmark.models import resnet_cifar_model
from official.benchmark.models import synthetic_util
from official.common import distribute_utils
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification.resnet import common
......@@ -142,7 +142,7 @@ def run(flags_obj):
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
......@@ -156,7 +156,7 @@ def run(flags_obj):
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data:
synthetic_util.set_up_synthetic_data()
......
......@@ -25,10 +25,9 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
......@@ -102,10 +101,10 @@ def run(flags_obj):
tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy.
_ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index)
_ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
......@@ -120,7 +119,7 @@ def run(flags_obj):
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
......
......@@ -26,10 +26,10 @@ from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from official.common import distribute_utils
# pylint: enable=wrong-import-order
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
EMBEDDING_DIM = 256
......@@ -177,14 +177,14 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
train_steps = flags_obj.train_steps
else:
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
strategy_scope = distribution_utils.get_strategy_scope(strategy)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
with strategy_scope:
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
use_cudnn=flags_obj.cudnn)
# When keras_use_ctl is False, Model.fit() automatically applies
# loss scaling so we don't need to create a LossScaleOptimizer.
# When keras_use_ctl is False, Model.fit() automatically applies
# loss scaling so we don't need to create a LossScaleOptimizer.
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(),
......@@ -276,7 +276,7 @@ def run(flags_obj):
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment