Commit 74a271b8 authored by Chen Chen, committed by A. Unique TensorFlower

When the number of files is less than the number of input pipelines, fall back to "_read_files_then_shard" (renamed from "_read_single_file" in this change), which sends all files to every worker.

PiperOrigin-RevId: 348865649
parent e2a31b15
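For context on why this fallback matters, here is a small illustrative sketch (mine, not from the commit; the file names are hypothetical) showing that sharding a file list across more input pipelines than there are files leaves some pipelines with no data at all:

import tensorflow as tf

# Two files, four input pipelines: file-level sharding starves pipelines 2-3.
files = tf.data.Dataset.from_tensor_slices(['a.tfrecord', 'b.tfrecord'])
for pipeline_id in range(4):
  shard = files.shard(num_shards=4, index=pipeline_id)
  print(pipeline_id, list(shard.as_numpy_iterator()))
# 0 [b'a.tfrecord']
# 1 [b'b.tfrecord']
# 2 []
# 3 []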
@@ -14,10 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 """A common dataset reader."""
 import random

 from typing import Any, Callable, Optional

+from absl import logging
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -139,10 +139,9 @@ class InputReader:
     self._tf_data_service_address = params.tf_data_service_address
     self._tf_data_service_job_name = params.tf_data_service_job_name

-  def _read_sharded_files(self,
-                          input_context: Optional[
-                              tf.distribute.InputContext] = None):
-    """Reads a dataset from sharded files."""
+  def _shard_files_then_read(
+      self, input_context: Optional[tf.distribute.InputContext] = None):
+    """Shards the data files and then sends a split to every worker to read."""
     dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)

     # Shuffle and repeat at file level.
@@ -170,14 +169,13 @@ class InputReader:
         deterministic=self._deterministic)
     return dataset

-  def _read_single_file(self,
-                        input_context: Optional[
-                            tf.distribute.InputContext] = None):
-    """Reads a dataset from a single file."""
-    # Read from `self._shards` if it is provided.
+  def _read_files_then_shard(
+      self, input_context: Optional[tf.distribute.InputContext] = None):
+    """Sends all data files to every worker and then shards by data."""
     dataset = self._dataset_fn(self._matched_files)

-    # When `input_file` is a path to a single file, disable auto sharding
+    # When `input_file` is a path to a single file or the number of files is
+    # less than the number of input pipelines, disable auto sharding
     # so that same input file is sent to all workers.
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = (
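The hunk cuts off mid-assignment. Given the comment about disabling auto sharding, the policy being assigned is presumably AutoShardPolicy.OFF; a self-contained sketch under that assumption:

import tensorflow as tf

dataset = tf.data.Dataset.range(8)  # stand-in for self._dataset_fn(...)
options = tf.data.Options()
# OFF disables tf.data's automatic sharding, so every worker sees the full
# dataset; any per-worker split then has to happen downstream, by data.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)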
@@ -238,9 +236,18 @@ class InputReader:
     if self._tfds_builder:
       dataset = self._read_tfds(input_context)
     elif len(self._matched_files) > 1:
-      dataset = self._read_sharded_files(input_context)
+      if input_context and (len(self._matched_files) <
+                            input_context.num_input_pipelines):
+        logging.warn(
+            'The number of files %d is less than the number of input pipelines '
+            '%d. We will send all input files to every worker. '
+            'Please consider sharding your data into more files.',
+            len(self._matched_files), input_context.num_input_pipelines)
+        dataset = self._read_files_then_shard(input_context)
+      else:
+        dataset = self._shard_files_then_read(input_context)
     elif len(self._matched_files) == 1:
-      dataset = self._read_single_file(input_context)
+      dataset = self._read_files_then_shard(input_context)
     else:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
...
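A hypothetical check mirroring the new dispatch (the file names and context values are made up for illustration): with two matched files and four input pipelines, the reader now warns and sends every file to every worker instead of sharding the file list.

import tensorflow as tf

matched_files = ['part-0.tfrecord', 'part-1.tfrecord']  # hypothetical names
ctx = tf.distribute.InputContext(num_input_pipelines=4, input_pipeline_id=0)
if ctx and len(matched_files) < ctx.num_input_pipelines:
  print('fallback: all %d files go to every worker' % len(matched_files))
else:
  print('shard %d files across %d pipelines'
        % (len(matched_files), ctx.num_input_pipelines))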