Commit 9ff61873 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Cleanup] Replace tf.distribute.experimental.TPUStrategy with tf.distribute.TPUStrategy

PiperOrigin-RevId: 342770296
parent 09cb3dff
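
In TF 2.3, `TPUStrategy` graduated from `tf.distribute.experimental` to the stable `tf.distribute` namespace, and this commit migrates call sites to the stable name. The change is a pure rename at the call site, sketched below (assumes a resolver for an already-initialized TPU system; see the full connection sketch at the end of this page):

```python
import tensorflow as tf

# Assumes the TPU system behind this resolver is already initialized;
# tpu="" refers to a locally attached TPU.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")

# Before this commit (deprecated alias):
# strategy = tf.distribute.experimental.TPUStrategy(resolver)

# After this commit (stable API, TF 2.3+):
strategy = tf.distribute.TPUStrategy(resolver)
```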
@@ -137,7 +137,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
     cluster_resolver = tpu_initialize(tpu_address)
-    return tf.distribute.experimental.TPUStrategy(cluster_resolver)
+    return tf.distribute.TPUStrategy(cluster_resolver)
   if distribution_strategy == "multi_worker_mirrored":
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(

@@ -245,7 +245,9 @@ def run_customized_training_loop(
   assert tf.executing_eagerly()
   if run_eagerly:
-    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        strategy,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       raise ValueError(
           'TPUStrategy should not run eagerly as it heavily relies on graph'
           ' optimization for the distributed system.')

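The guard above exists because TPU execution relies on XLA graph compilation: a training step that runs eagerly, line by line in Python, cannot be lowered onto the device. A minimal sketch of the graph-mode pattern the loop expects instead, with a hypothetical toy model, optimizer, and loss standing in for the real ones:

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8
strategy = tf.distribute.get_strategy()  # a TPUStrategy in the real loop

# Hypothetical model and optimizer, created under the strategy scope so
# their variables are replicated correctly.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.01)
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.SUM)

@tf.function  # the step must be traced into a graph before hitting the TPU
def train_step(iterator):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      loss = loss_fn(labels, model(features, training=True)) / GLOBAL_BATCH_SIZE
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  # strategy.run replicates step_fn across all replicas in sync.
  return strategy.run(step_fn, args=(next(iterator),))
```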
@@ -186,7 +186,9 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_eager_single_step(self, distribution):
     model_dir = self.create_tempdir().full_path
-    if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       with self.assertRaises(ValueError):
         self.run_training(
             distribution, model_dir, steps_per_loop=1, run_eagerly=True)

@@ -66,8 +66,9 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
           mode="eager"))
   def test_create_model_with_ds(self, distribution):
     with distribution.scope():
-      padded_decode = isinstance(distribution,
-                                 tf.distribute.experimental.TPUStrategy)
+      padded_decode = isinstance(
+          distribution,
+          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
       decode_max_length = 10
       batch_size = 4
       model = self._build_model(padded_decode, decode_max_length)

@@ -218,7 +218,7 @@ def get_input_dataset(input_file_pattern,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(

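The dataset-function indirection in this hunk (and the identical ones further down) is there for TPU pods: every worker host has to build its own input pipeline, so the strategy is given a callable that receives a `tf.distribute.InputContext` rather than a ready-made dataset, and the global batch size must divide evenly across replicas. A minimal sketch of that pattern, with a hypothetical file pattern and batch size (API names as of the TF 2.3/2.4 era of this commit):

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # must be divisible by strategy.num_replicas_in_sync
strategy = tf.distribute.get_strategy()  # a TPUStrategy in the code above

def dataset_fn(input_context):
  # Each input pipeline takes its own shard of the files and batches at
  # the per-replica size derived from the global batch size.
  per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  files = tf.io.gfile.glob("/path/to/train-*.tfrecord")  # hypothetical pattern
  dataset = tf.data.TFRecordDataset(files)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(per_replica_batch, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
```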
@@ -179,8 +179,9 @@ class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_bert2bert_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._config.override(
         {
             "beam_size": 3,

@@ -286,8 +287,9 @@ class NHNetTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_nhnet_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._nhnet_config.override(
         {
             "beam_size": 4,

@@ -210,7 +210,7 @@ def run():
   if "eval" in FLAGS.mode:
     timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
     # Uses padded decoding for TPU. Always uses cache.
-    padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(strategy, tf.distribute.TPUStrategy)
     params.override({
         "padded_decode": padded_decode,
     }, is_strict=False)

@@ -182,8 +182,7 @@ class TransformerTask(object):
   @property
   def use_tpu(self):
     if self.distribution_strategy:
-      return isinstance(self.distribution_strategy,
-                        tf.distribute.experimental.TPUStrategy)
+      return isinstance(self.distribution_strategy, tf.distribute.TPUStrategy)
     return False

   def train(self):

@@ -175,7 +175,7 @@ def get_classification_input_data(batch_size, seq_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(

@@ -208,7 +208,7 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(

@@ -592,7 +592,7 @@ def get_pretrain_input_data(batch_size,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   split = "train"
   bsz_per_host = int(batch_size / num_hosts)
   record_glob_base = format_filename(

@@ -135,7 +135,7 @@ def get_v1_distribution_strategy(params):
     }
     os.environ["TF_CONFIG"] = json.dumps(tf_config_env)
-    distribution = tf.distribute.experimental.TPUStrategy(
+    distribution = tf.distribute.TPUStrategy(
         tpu_cluster_resolver, steps_per_run=100)
   else:

@@ -135,7 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
     if training:
       # TODO(arashwan): make MeanIoU tpu friendly.
       if not isinstance(tf.distribute.get_strategy(),
-                        tf.distribute.experimental.TPUStrategy):
+                        tf.distribute.TPUStrategy):
         metrics.append(segmentation_metrics.MeanIoU(
             name='mean_iou',
             num_classes=self.task_config.model.num_classes,

@@ -43,7 +43,7 @@ builder to 'records' or 'tfds' in the configurations.
 Note: These models will **not** work with TPUs on Colab.
 You can train image classification models on Cloud TPUs using
-[tf.distribute.experimental.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy?version=nightly).
+[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy?version=nightly).
 If you are not familiar with Cloud TPUs, it is strongly recommended that you go
 through the
 [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to

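For readers following the quickstart, a minimal sketch of wiring a Cloud TPU up to the stable strategy API (the TPU name is a placeholder for your own instance):

```python
import tensorflow as tf

# Hypothetical Cloud TPU name; on a TPU VM, tpu="" resolves the local TPU.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  # Variables created in scope are placed and replicated for the TPU.
  model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
  model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy")
```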