Commit faf7a4f5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Clean up] Consolidate distribution utils.

PiperOrigin-RevId: 331359058
parent 420fd1cf
...@@ -14,29 +14,22 @@ ...@@ -14,29 +14,22 @@
# ============================================================================== # ==============================================================================
"""Executes BERT benchmarks and accuracy tests.""" """Executes BERT benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools import functools
import json import json
import math import math
import os import os
import time import time
# pylint: disable=g-bad-import-order
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils as benchmark_utils from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import owner_utils from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.bert import run_classifier from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long # pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt' PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
...@@ -76,10 +69,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -76,10 +69,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
eval_steps = int( eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size)) math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if self.tpu: if self.tpu:
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu) distribution_strategy='tpu', tpu_address=self.tpu)
else: else:
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off', distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus) num_gpus=self.num_gpus)
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# ============================================================================== # ==============================================================================
"""Executes benchmark testing for bert pretraining.""" """Executes benchmark testing for bert pretraining."""
# pylint: disable=line-too-long # pylint: disable=line-too-long
from __future__ import print_function
import json import json
import os import os
import time import time
...@@ -24,14 +22,14 @@ from typing import Optional ...@@ -24,14 +22,14 @@ from typing import Optional
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.benchmark import benchmark_wrappers from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils from official.benchmark import bert_benchmark_utils
from official.benchmark import owner_utils from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import run_pretraining from official.nlp.bert import run_pretraining
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
# Pretrain masked lanauge modeling accuracy range: # Pretrain masked lanauge modeling accuracy range:
MIN_MLM_ACCURACY = 0.635 MIN_MLM_ACCURACY = 0.635
...@@ -85,13 +83,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -85,13 +83,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
A `tf.distribute.DistibutionStrategy` object. A `tf.distribute.DistibutionStrategy` object.
""" """
if self.tpu or ds_type == 'tpu': if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy( return distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu) distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored': elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy. # Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts, _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index) FLAGS.task_index)
return distribution_utils.get_distribution_strategy( return distribute_utils.get_distribution_strategy(
distribution_strategy=ds_type, distribution_strategy=ds_type,
num_gpus=FLAGS.num_gpus, num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg) all_reduce_alg=FLAGS.all_reduce_alg)
......
...@@ -13,29 +13,21 @@ ...@@ -13,29 +13,21 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests.""" """Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json import json
import os import os
import time import time
# pylint: disable=g-bad-import-order
from absl import flags from absl import flags
from absl import logging from absl import logging
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils as benchmark_utils from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import owner_utils from official.benchmark import owner_utils
from official.common import distribute_utils
from official.nlp.bert import run_squad from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.benchmark import benchmark_wrappers
# pylint: disable=line-too-long # pylint: disable=line-too-long
...@@ -83,13 +75,13 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -83,13 +75,13 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
A `tf.distribute.DistibutionStrategy` object. A `tf.distribute.DistibutionStrategy` object.
""" """
if self.tpu or ds_type == 'tpu': if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy( return distribute_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu) distribution_strategy='tpu', tpu_address=self.tpu)
elif ds_type == 'multi_worker_mirrored': elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy. # Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts, _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index) FLAGS.task_index)
return distribution_utils.get_distribution_strategy( return distribute_utils.get_distribution_strategy(
distribution_strategy=ds_type, distribution_strategy=ds_type,
num_gpus=self.num_gpus, num_gpus=self.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg) all_reduce_alg=FLAGS.all_reduce_alg)
......
...@@ -27,8 +27,8 @@ import tensorflow as tf ...@@ -27,8 +27,8 @@ import tensorflow as tf
from official.benchmark.models import cifar_preprocessing from official.benchmark.models import cifar_preprocessing
from official.benchmark.models import resnet_cifar_model from official.benchmark.models import resnet_cifar_model
from official.benchmark.models import synthetic_util from official.benchmark.models import synthetic_util
from official.common import distribute_utils
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification.resnet import common from official.vision.image_classification.resnet import common
...@@ -142,7 +142,7 @@ def run(flags_obj): ...@@ -142,7 +142,7 @@ def run(flags_obj):
else 'channels_last') else 'channels_last')
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus, num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
...@@ -156,7 +156,7 @@ def run(flags_obj): ...@@ -156,7 +156,7 @@ def run(flags_obj):
flags_obj.enable_get_next_as_optional flags_obj.enable_get_next_as_optional
) )
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribute_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
synthetic_util.set_up_synthetic_data() synthetic_util.set_up_synthetic_data()
......
...@@ -25,10 +25,9 @@ from absl import app ...@@ -25,10 +25,9 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance from official.modeling import performance
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils from official.vision.image_classification import test_utils
...@@ -102,10 +101,10 @@ def run(flags_obj): ...@@ -102,10 +101,10 @@ def run(flags_obj):
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy. # Configures cluster spec for distribution strategy.
_ = distribution_utils.configure_cluster(flags_obj.worker_hosts, _ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
flags_obj.task_index) flags_obj.task_index)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus, num_gpus=flags_obj.num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
...@@ -120,7 +119,7 @@ def run(flags_obj): ...@@ -120,7 +119,7 @@ def run(flags_obj):
flags_obj.enable_get_next_as_optional flags_obj.enable_get_next_as_optional
) )
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribute_utils.get_strategy_scope(strategy)
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
......
...@@ -26,10 +26,10 @@ from absl import app ...@@ -26,10 +26,10 @@ from absl import app
from absl import flags from absl import flags
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils
# pylint: enable=wrong-import-order # pylint: enable=wrong-import-order
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
EMBEDDING_DIM = 256 EMBEDDING_DIM = 256
...@@ -177,14 +177,14 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None): ...@@ -177,14 +177,14 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
train_steps = flags_obj.train_steps train_steps = flags_obj.train_steps
else: else:
train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size train_steps = BATCHES_PER_EPOCH // flags_obj.batch_size
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribute_utils.get_strategy_scope(strategy)
with strategy_scope: with strategy_scope:
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size, model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
use_cudnn=flags_obj.cudnn) use_cudnn=flags_obj.cudnn)
# When keras_use_ctl is False, Model.fit() automatically applies # When keras_use_ctl is False, Model.fit() automatically applies
# loss scaling so we don't need to create a LossScaleOptimizer. # loss scaling so we don't need to create a LossScaleOptimizer.
model.compile( model.compile(
optimizer=tf.keras.optimizers.Adam(), optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(), loss=tf.keras.losses.CategoricalCrossentropy(),
...@@ -276,7 +276,7 @@ def run(flags_obj): ...@@ -276,7 +276,7 @@ def run(flags_obj):
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment