Commit 8f345563 authored by Chenkai Kuang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334517779
parent 2986bcaf
@@ -207,8 +207,7 @@ class DistributedExecutor(object):
     # across workers. Since Dataset instance cannot be cloned in eager mode,
     # we instead pass callable that returns a dataset.
     if self._is_multi_host:
-      return iter(
-          strategy.experimental_distribute_datasets_from_function(input_fn))
+      return iter(strategy.distribute_datasets_from_function(input_fn))
     else:
       input_data = input_fn()
       return iter(strategy.experimental_distribute_dataset(input_data))
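The hunk above switches the multi-host branch to the now-stable `distribute_datasets_from_function` name. Below is a minimal, self-contained sketch contrasting the two code paths it chooses between; the dataset contents, batch size, and `dataset_fn` are hypothetical placeholders, not code from this commit.

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # any tf.distribute.Strategy works here
GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
  # Called once per worker: shard by input pipeline and batch with the
  # per-replica batch size, since no automatic rebatching happens here.
  ds = tf.data.Dataset.range(1024)
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE))

# Multi-host path: pass a callable so each worker builds its own dataset.
multi_host_iter = iter(strategy.distribute_datasets_from_function(dataset_fn))

# Single-host path: pass a dataset instance batched with the global size;
# it is rebatched across replicas automatically.
single_host_iter = iter(strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(1024).batch(GLOBAL_BATCH_SIZE)))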
@@ -65,8 +65,7 @@ def _get_input_iterator(input_fn, strategy):
   # pass callable that returns a dataset.
   if not callable(input_fn):
     raise ValueError('`input_fn` should be a closure that returns a dataset.')
-  iterator = iter(
-      strategy.experimental_distribute_datasets_from_function(input_fn))
+  iterator = iter(strategy.distribute_datasets_from_function(input_fn))
   return iterator
@@ -325,8 +325,7 @@ def get_predictions_and_labels(strategy,
       tf.experimental.async_clear_error()
     return preds, golds
-  test_iter = iter(
-      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
+  test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
   predictions, labels = _run_evaluation(test_iter)
   return predictions, labels
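For context on the lines above, here is a hedged sketch of how such an evaluation loop typically drains a distributed iterator and clears the pending error raised at end-of-input under asynchronous (e.g. TPU) execution. `eval_input_fn`, `eval_step`, and the toy data are stand-ins invented for this sketch, not the actual code behind this hunk.

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default strategy keeps the sketch runnable

def eval_input_fn(ctx):
  del ctx  # stand-in for the real per-worker eval input function
  features = tf.ones([6, 4])
  labels = tf.zeros([6], dtype=tf.int32)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)

@tf.function
def eval_step(iterator):
  features, labels = next(iterator)
  # Placeholder for model inference on each replica.
  preds = strategy.run(lambda x: tf.argmax(x, axis=-1), args=(features,))
  return preds, labels

def _run_evaluation(test_iterator):
  preds, golds = [], []
  try:
    while True:
      p, g = eval_step(test_iterator)
      preds.append(p)
      golds.append(g)
  except (StopIteration, tf.errors.OutOfRangeError):
    # Under async execution the out-of-range error is pending and must be
    # cleared before the host continues.
    tf.experimental.async_clear_error()
  return preds, golds

test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
predictions, labels = _run_evaluation(test_iter)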
@@ -186,8 +186,7 @@ def predict_squad_customized(strategy, input_meta_data, predict_tfrecord_path,
       FLAGS.predict_batch_size,
       is_training=False)
   predict_iterator = iter(
-      strategy.experimental_distribute_datasets_from_function(
-          predict_dataset_fn))
+      strategy.distribute_datasets_from_function(predict_dataset_fn))

   @tf.function
   def predict_step(iterator):
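As a rough, hypothetical sketch of the pattern this hunk feeds into: a tf.function pulls one batch from the distributed iterator per call, runs the replica function with `strategy.run`, and the caller unwraps per-replica values with `experimental_local_results`. The input function and the "model" below are placeholders, not the real SQuAD prediction code.

import tensorflow as tf

strategy = tf.distribute.get_strategy()

def predict_dataset_fn(ctx):
  del ctx  # stand-in for the real SQuAD prediction input function
  return tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(2)

predict_iterator = iter(
    strategy.distribute_datasets_from_function(predict_dataset_fn))

@tf.function
def predict_step(iterator):
  def _replica_fn(features):
    return tf.reduce_sum(features, axis=-1)  # placeholder for model inference
  return strategy.run(_replica_fn, args=(next(iterator),))

all_outputs = []
for _ in range(4):  # 8 examples / per-replica batch of 2 / 1 replica
  outputs = predict_step(predict_iterator)
  # Unwrap the per-replica values into concrete tensors on the host.
  all_outputs.extend(strategy.experimental_local_results(outputs))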
@@ -230,7 +230,7 @@ def get_input_dataset(input_file_pattern,
             strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -249,6 +249,6 @@ def get_input_dataset(input_file_pattern,
         input_pipeline_context=ctx)

   if use_dataset_fn:
-    return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
+    return strategy.distribute_datasets_from_function(_dataset_fn)
   else:
     return strategy.experimental_distribute_dataset(_dataset_fn())
@@ -80,7 +80,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     metrics = task.build_metrics()
     strategy = tf.distribute.get_strategy()
-    dataset = strategy.experimental_distribute_datasets_from_function(
+    dataset = strategy.distribute_datasets_from_function(
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)
@@ -66,7 +66,7 @@ class TaggingTest(tf.test.TestCase):
     metrics = task.build_metrics()
     strategy = tf.distribute.get_strategy()
-    dataset = strategy.experimental_distribute_datasets_from_function(
+    dataset = strategy.distribute_datasets_from_function(
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)
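The two test hunks above hand `distribute_datasets_from_function` a `functools.partial` so that the only remaining argument is the `tf.distribute.InputContext` the strategy supplies. A small stand-alone illustration follows; `build_inputs` and the config dict are made up for this sketch.

import functools
import tensorflow as tf

strategy = tf.distribute.get_strategy()

def build_inputs(data_config, input_context=None):
  # Hypothetical stand-in for task.build_inputs: data_config is pinned by the
  # partial below, so the callable receives only the InputContext.
  del input_context
  return tf.data.Dataset.range(data_config['size']).batch(2)

train_data = {'size': 8}  # placeholder for config.train_data
dataset = strategy.distribute_datasets_from_function(
    functools.partial(build_inputs, train_data))
iterator = iter(dataset)
first_batch = next(iterator)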
@@ -227,12 +227,11 @@ class TransformerTask(object):
     if self.use_tpu:
       # Different from experimental_distribute_dataset,
-      # experimental_distribute_datasets_from_function requires
+      # distribute_datasets_from_function requires
       # per-replica/local batch size.
       params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
       train_ds = (
-          self.distribution_strategy
-          .experimental_distribute_datasets_from_function(
+          self.distribution_strategy.distribute_datasets_from_function(
               lambda ctx: data_pipeline.train_input_fn(params, ctx)))
     else:
       train_ds = data_pipeline.train_input_fn(params)
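Because `distribute_datasets_from_function` does no automatic rebatching, the TPU branch above divides the global batch size down to a per-replica size before building the dataset. Here is a self-contained sketch of that arithmetic and the ctx-taking lambda; `train_input_fn` and `params` are invented for illustration and are not the real `data_pipeline` code.

import tensorflow as tf

strategy = tf.distribute.get_strategy()
params = {'batch_size': 32}  # hypothetical global batch size

# The input function must batch with the per-replica size itself.
params['batch_size'] //= strategy.num_replicas_in_sync

def train_input_fn(params, ctx):
  # Shard by worker, then batch with the per-replica batch size.
  ds = tf.data.Dataset.range(1024).repeat()
  ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return ds.batch(params['batch_size'], drop_remainder=True)

train_ds = strategy.distribute_datasets_from_function(
    lambda ctx: train_input_fn(params, ctx))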
@@ -167,8 +167,7 @@ def get_input_iterator(input_fn, strategy):
   # pass callable that returns a dataset.
   input_data = input_fn()
   if callable(input_data):
-    iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(input_data))
+    iterator = iter(strategy.distribute_datasets_from_function(input_data))
   else:
     iterator = iter(strategy.experimental_distribute_dataset(input_data))
   return iterator
@@ -189,7 +188,7 @@ def get_classification_input_data(batch_size, seq_len, strategy, is_training,
             strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -222,7 +221,7 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
             strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -624,7 +623,7 @@ def get_pretrain_input_data(batch_size,
             strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -297,8 +297,7 @@ class DatasetBuilder:
                    'Passed a strategy with %d devices, but expected'
                    '%d devices.', strategy.num_replicas_in_sync,
                    self.config.num_devices)
-      dataset = strategy.experimental_distribute_datasets_from_function(
-          self._build)
+      dataset = strategy.distribute_datasets_from_function(self._build)
     else:
       dataset = self._build()
@@ -42,7 +42,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
             self.strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)