Commit 8f345563 authored by Chenkai Kuang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334517779
parent 2986bcaf
@@ -207,8 +207,7 @@ class DistributedExecutor(object):
     # across workers. Since Dataset instance cannot be cloned in eager mode,
     # we instead pass callable that returns a dataset.
     if self._is_multi_host:
-      return iter(
-          strategy.experimental_distribute_datasets_from_function(input_fn))
+      return iter(strategy.distribute_datasets_from_function(input_fn))
     else:
       input_data = input_fn()
       return iter(strategy.experimental_distribute_dataset(input_data))

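For context on the two input APIs touched throughout this commit, the sketch below contrasts them; it is illustrative only (the strategy, batch size, and dataset are placeholders, not code from this repository), and it assumes a TF release where the non-experimental name `distribute_datasets_from_function` exists. The callable form is what makes the multi-host branch work: each worker builds its own dataset in eager mode instead of cloning a `Dataset` instance.

# Illustrative sketch only, not from this repo.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8

def dataset_fn(input_context):
  # Runs once per input pipeline (per host); batch with the per-replica size
  # because this API does not rebatch automatically.
  per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  ds = tf.data.Dataset.range(64)
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(per_replica_batch)

# Pass the callable, so each worker can build its own dataset in eager mode.
dist_from_fn = strategy.distribute_datasets_from_function(dataset_fn)

# The instance-based API takes a dataset batched with the GLOBAL batch size
# and rebatches it across replicas automatically.
dist_from_ds = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(64).batch(GLOBAL_BATCH_SIZE))
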
@@ -65,8 +65,7 @@ def _get_input_iterator(input_fn, strategy):
   # pass callable that returns a dataset.
   if not callable(input_fn):
     raise ValueError('`input_fn` should be a closure that returns a dataset.')
-  iterator = iter(
-      strategy.experimental_distribute_datasets_from_function(input_fn))
+  iterator = iter(strategy.distribute_datasets_from_function(input_fn))
   return iterator

@@ -325,8 +325,7 @@ def get_predictions_and_labels(strategy,
       tf.experimental.async_clear_error()
     return preds, golds

-  test_iter = iter(
-      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
+  test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
   predictions, labels = _run_evaluation(test_iter)
   return predictions, labels

@@ -186,8 +186,7 @@ def predict_squad_customized(strategy, input_meta_data, predict_tfrecord_path,
       FLAGS.predict_batch_size,
       is_training=False)
   predict_iterator = iter(
-      strategy.experimental_distribute_datasets_from_function(
-          predict_dataset_fn))
+      strategy.distribute_datasets_from_function(predict_dataset_fn))

   @tf.function
   def predict_step(iterator):

@@ -230,7 +230,7 @@ def get_input_dataset(input_file_pattern,
         strategy.num_replicas_in_sync))
   # As auto rebatching is not supported in
-  # `experimental_distribute_datasets_from_function()` API, which is
+  # `distribute_datasets_from_function()` API, which is
   # required when cloning dataset to multiple workers in eager mode,
   # we use per-replica batch size.
   batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -249,6 +249,6 @@ def get_input_dataset(input_file_pattern,
         input_pipeline_context=ctx)
   if use_dataset_fn:
-    return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
+    return strategy.distribute_datasets_from_function(_dataset_fn)
   else:
     return strategy.experimental_distribute_dataset(_dataset_fn())

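The comment in the hunk above is why this file (and several below) divides the global batch size by `strategy.num_replicas_in_sync`. A minimal sketch of that convention follows; the names `make_dataset` and `global_batch_size` are placeholders and not part of this change.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 32

def make_dataset(ctx):
  # ctx is the tf.distribute.InputContext supplied by
  # distribute_datasets_from_function; no automatic rebatching happens, so
  # the function must batch with the per-replica size itself.
  per_replica = ctx.get_per_replica_batch_size(global_batch_size)
  # Equivalent to the manual division done in the hunks above:
  # per_replica == global_batch_size // ctx.num_replicas_in_sync
  return tf.data.Dataset.range(1024).batch(per_replica, drop_remainder=True)

dist_dataset = strategy.distribute_datasets_from_function(make_dataset)
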
@@ -80,7 +80,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     metrics = task.build_metrics()
     strategy = tf.distribute.get_strategy()
-    dataset = strategy.experimental_distribute_datasets_from_function(
+    dataset = strategy.distribute_datasets_from_function(
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)

@@ -66,7 +66,7 @@ class TaggingTest(tf.test.TestCase):
     metrics = task.build_metrics()
     strategy = tf.distribute.get_strategy()
-    dataset = strategy.experimental_distribute_datasets_from_function(
+    dataset = strategy.distribute_datasets_from_function(
         functools.partial(task.build_inputs, config.train_data))
     iterator = iter(dataset)

@@ -227,12 +227,11 @@ class TransformerTask(object):
     if self.use_tpu:
       # Different from experimental_distribute_dataset,
-      # experimental_distribute_datasets_from_function requires
+      # distribute_datasets_from_function requires
       # per-replica/local batch size.
       params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
       train_ds = (
-          self.distribution_strategy
-          .experimental_distribute_datasets_from_function(
+          self.distribution_strategy.distribute_datasets_from_function(
               lambda ctx: data_pipeline.train_input_fn(params, ctx)))
     else:
       train_ds = data_pipeline.train_input_fn(params)

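The Transformer hunk also shows the lambda form, where the dataset function receives the `tf.distribute.InputContext` as `ctx`. A hedged sketch of that pattern is below; `train_input_fn` here is a stand-in, not the actual `data_pipeline.train_input_fn`.

import tensorflow as tf

def train_input_fn(params, ctx=None):
  ds = tf.data.Dataset.range(10000)
  if ctx is not None and ctx.num_input_pipelines > 1:
    # Shard the data so each host reads a disjoint slice.
    ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  # params["batch_size"] is already the per-replica batch size here, matching
  # the division by num_replicas_in_sync performed above.
  return ds.batch(int(params["batch_size"]), drop_remainder=True)

strategy = tf.distribute.MirroredStrategy()
params = {"batch_size": 64 / strategy.num_replicas_in_sync}
train_ds = strategy.distribute_datasets_from_function(
    lambda ctx: train_input_fn(params, ctx))
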
@@ -167,8 +167,7 @@ def get_input_iterator(input_fn, strategy):
   # pass callable that returns a dataset.
   input_data = input_fn()
   if callable(input_data):
-    iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(input_data))
+    iterator = iter(strategy.distribute_datasets_from_function(input_data))
   else:
     iterator = iter(strategy.experimental_distribute_dataset(input_data))
   return iterator
@@ -189,7 +188,7 @@ def get_classification_input_data(batch_size, seq_len, strategy, is_training,
           strategy.num_replicas_in_sync))
   # As auto rebatching is not supported in
-  # `experimental_distribute_datasets_from_function()` API, which is
+  # `distribute_datasets_from_function()` API, which is
   # required when cloning dataset to multiple workers in eager mode,
   # we use per-replica batch size.
   batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -222,7 +221,7 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
           strategy.num_replicas_in_sync))
   # As auto rebatching is not supported in
-  # `experimental_distribute_datasets_from_function()` API, which is
+  # `distribute_datasets_from_function()` API, which is
   # required when cloning dataset to multiple workers in eager mode,
   # we use per-replica batch size.
   batch_size = int(batch_size / strategy.num_replicas_in_sync)
@@ -624,7 +623,7 @@ def get_pretrain_input_data(batch_size,
           strategy.num_replicas_in_sync))
   # As auto rebatching is not supported in
-  # `experimental_distribute_datasets_from_function()` API, which is
+  # `distribute_datasets_from_function()` API, which is
   # required when cloning dataset to multiple workers in eager mode,
   # we use per-replica batch size.
   batch_size = int(batch_size / strategy.num_replicas_in_sync)

@@ -297,8 +297,7 @@ class DatasetBuilder:
           'Passed a strategy with %d devices, but expected'
           '%d devices.', strategy.num_replicas_in_sync,
           self.config.num_devices)
-      dataset = strategy.experimental_distribute_datasets_from_function(
-          self._build)
+      dataset = strategy.distribute_datasets_from_function(self._build)
     else:
       dataset = self._build()

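For the `DatasetBuilder` change, note that the bound method handed to `distribute_datasets_from_function` is invoked with a `tf.distribute.InputContext`. The toy class below is only a sketch of that shape under that assumption; the real builder's signature and logic may differ.

import tensorflow as tf

class ToyDatasetBuilder:
  """Minimal stand-in; not the real DatasetBuilder."""

  def __init__(self, global_batch_size=16):
    self._global_batch_size = global_batch_size

  def _build(self, input_context=None):
    batch = self._global_batch_size
    if input_context is not None:
      # Called by distribute_datasets_from_function with an InputContext.
      batch = input_context.get_per_replica_batch_size(self._global_batch_size)
    return tf.data.Dataset.range(256).batch(batch)

  def build(self, strategy=None):
    if strategy is not None:
      return strategy.distribute_datasets_from_function(self._build)
    return self._build()

dataset = ToyDatasetBuilder().build(tf.distribute.MirroredStrategy())
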
@@ -42,7 +42,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
           self.strategy.num_replicas_in_sync))
     # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
+    # `distribute_datasets_from_function()` API, which is
     # required when cloning dataset to multiple workers in eager mode,
     # we use per-replica batch size.
     self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)