Commit b6ece654 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Cleanup] Replace tf.distribute.experimental.TPUStrategy with tf.distribute.TPUStrategy

PiperOrigin-RevId: 342770296
parent a9d13bf1
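
For orientation, this is what the migration target looks like at a call site. A minimal sketch, not taken from this commit, assuming a reachable TPU (`tpu=""` addresses a local TPU runtime, matching the convention in the diffs below):

```python
import tensorflow as tf

# Stand up the TPU system, then build the strategy via the stable symbol.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)  # was tf.distribute.experimental.TPUStrategy
```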
@@ -137,7 +137,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
     cluster_resolver = tpu_initialize(tpu_address)
-    return tf.distribute.experimental.TPUStrategy(cluster_resolver)
+    return tf.distribute.TPUStrategy(cluster_resolver)
   if distribution_strategy == "multi_worker_mirrored":
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
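
A hypothetical use of the helper above (only the parameters visible in the hunk are shown; the real signature has more):

```python
# Hypothetical usage; tpu_address="" selects local TPUs per the comment above.
strategy = get_distribution_strategy(distribution_strategy="tpu", tpu_address="")
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # toy model
```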
@@ -245,7 +245,9 @@ def run_customized_training_loop(
   assert tf.executing_eagerly()

   if run_eagerly:
-    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        strategy,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       raise ValueError(
           'TPUStrategy should not run eagerly as it heavily relies on graph'
           ' optimization for the distributed system.')
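
The two-element isinstance tuple is the recurring pattern of this commit: during the migration a strategy may have been constructed from either symbol, and the two classes are distinct, so call sites must accept both. A sketch of a helper that would centralize the check (the name is made up, not part of this commit):

```python
def _is_tpu_strategy(strategy):
  # Hypothetical helper: true for a strategy built from either the stable
  # or the experimental TPUStrategy symbol.
  return isinstance(
      strategy,
      (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
```

The same tuple reappears in the test and decoding hunks below.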
@@ -186,7 +186,9 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_eager_single_step(self, distribution):
     model_dir = self.create_tempdir().full_path
-    if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
+    if isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
       with self.assertRaises(ValueError):
         self.run_training(
             distribution, model_dir, steps_per_loop=1, run_eagerly=True)
@@ -66,8 +66,9 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
           mode="eager"))
   def test_create_model_with_ds(self, distribution):
     with distribution.scope():
-      padded_decode = isinstance(distribution,
-                                 tf.distribute.experimental.TPUStrategy)
+      padded_decode = isinstance(
+          distribution,
+          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
       decode_max_length = 10
       batch_size = 4
       model = self._build_model(padded_decode, decode_max_length)
@@ -218,7 +218,7 @@ def get_input_dataset(input_file_pattern,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(
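
The comment in this hunk is the heart of the pattern: on TPU pods each worker must build its own input pipeline, so the strategy receives a dataset-building function instead of a dataset instance, and the global batch size has to divide evenly across replicas (hence the ValueError). A sketch of both paths, assuming TF 2.4's distribute_datasets_from_function (older releases name it experimental_distribute_datasets_from_function); the batch size is illustrative:

```python
GLOBAL_BATCH_SIZE = 64  # illustrative; must divide evenly across replicas

def dataset_fn(ctx: tf.distribute.InputContext):
  # Runs once per worker; ctx converts the global batch size into the
  # per-replica batch size for this worker.
  per_replica = ctx.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  return tf.data.Dataset.range(1024).batch(per_replica, drop_remainder=True)

if isinstance(strategy, tf.distribute.TPUStrategy):
  # TPU pods: pass the function so every worker clones the pipeline locally.
  dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
else:
  ds = tf.data.Dataset.range(1024).batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
  dist_ds = strategy.experimental_distribute_dataset(ds)
```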
@@ -179,8 +179,9 @@ class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_bert2bert_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._config.override(
         {
             "beam_size": 3,
@@ -286,8 +287,9 @@ class NHNetTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(all_strategy_combinations())
   def test_nhnet_eval(self, distribution):
     seq_length = 10
-    padded_decode = isinstance(distribution,
-                               tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(
+        distribution,
+        (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
     self._nhnet_config.override(
         {
             "beam_size": 4,
@@ -210,7 +210,7 @@ def run():
   if "eval" in FLAGS.mode:
     timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
     # Uses padded decoding for TPU. Always uses cache.
-    padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+    padded_decode = isinstance(strategy, tf.distribute.TPUStrategy)
     params.override({
         "padded_decode": padded_decode,
     }, is_strict=False)
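
Padded decoding exists because XLA on TPUs compiles fixed shapes: rather than growing the decoded sequence token by token, the output buffer is preallocated at the maximum length and updated in place. A toy illustration of a static-shape decode loop, not the model garden's decoder:

```python
import tensorflow as tf

max_decode_length, batch = 10, 4  # toy sizes
ids = tf.zeros([batch, max_decode_length], tf.int32)  # fixed-shape buffer
for step in range(max_decode_length):
  next_id = tf.ones([batch], tf.int32)  # stand-in for one decode step's output
  indices = tf.stack(
      [tf.range(batch), tf.fill([batch], step)], axis=1)  # [batch, 2]
  ids = tf.tensor_scatter_nd_update(ids, indices, next_id)
  # ids.shape stays [4, 10] on every iteration, which XLA can compile once.
```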
@@ -182,8 +182,7 @@ class TransformerTask(object):
   @property
   def use_tpu(self):
     if self.distribution_strategy:
-      return isinstance(self.distribution_strategy,
-                        tf.distribute.experimental.TPUStrategy)
+      return isinstance(self.distribution_strategy, tf.distribute.TPUStrategy)
     return False

   def train(self):
@@ -175,7 +175,7 @@ def get_classification_input_data(batch_size, seq_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(
@@ -208,7 +208,7 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
       raise ValueError(
@@ -592,7 +592,7 @@ def get_pretrain_input_data(batch_size,
   # When using TPU pods, we need to clone dataset across
   # workers and need to pass in function that returns the dataset rather
   # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
   split = "train"
   bsz_per_host = int(batch_size / num_hosts)
   record_glob_base = format_filename(
@@ -135,7 +135,7 @@ def get_v1_distribution_strategy(params):
     }
     os.environ["TF_CONFIG"] = json.dumps(tf_config_env)

-    distribution = tf.distribute.experimental.TPUStrategy(
+    distribution = tf.distribute.TPUStrategy(
         tpu_cluster_resolver, steps_per_run=100)
   else:
@@ -135,7 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
     if training:
       # TODO(arashwan): make MeanIoU tpu friendly.
       if not isinstance(tf.distribute.get_strategy(),
-                        tf.distribute.experimental.TPUStrategy):
+                        tf.distribute.TPUStrategy):
         metrics.append(segmentation_metrics.MeanIoU(
             name='mean_iou',
             num_classes=self.task_config.model.num_classes,
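
The TODO explains the gate: MeanIoU accumulates a confusion matrix whose update ops are not yet TPU friendly, so the metric is only registered off-TPU. A hedged sketch of the same gating with a stock Keras metric (the class count is illustrative):

```python
# Hypothetical sketch mirroring the hunk above: only register metrics that
# cannot run on TPU when the active strategy is not a TPUStrategy.
metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
if not isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
  metrics.append(tf.keras.metrics.MeanIoU(num_classes=21, name="mean_iou"))
```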
@@ -43,7 +43,7 @@ builder to 'records' or 'tfds' in the configurations.
 Note: These models will **not** work with TPUs on Colab.

 You can train image classification models on Cloud TPUs using
-[tf.distribute.experimental.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy?version=nightly).
+[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy?version=nightly).
 If you are not familiar with Cloud TPUs, it is strongly recommended that you go
 through the
 [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to