Commit ca509f79 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 358289482
parent a5e7e2ce
@@ -14,7 +14,7 @@
 """A common dataset reader."""
 import random
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 from absl import logging
 import tensorflow as tf
@@ -27,6 +27,13 @@ def _get_random_integer():
   return random.randint(0, (1 << 31) - 1)
 
 
+def _maybe_map_fn(dataset: tf.data.Dataset,
+                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
+  """Calls dataset.map if a valid function is passed in."""
+  return dataset if fn is None else dataset.map(
+      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
@@ -74,38 +81,7 @@ class InputReader:
     self._tfds_builder = None
     self._matched_files = []
     if params.input_path:
-      # Read dataset from files.
-      usage = ('`input_path` should be either (1) a str indicating a file '
-               'path/pattern, or (2) a str indicating multiple file '
-               'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
-               '"a,b,c", or (3) a list of str, each of which is a file '
-               'path/pattern or multiple file paths/patterns separated by '
-               'comma, but got: %s')
-      if isinstance(params.input_path, str):
-        input_path_list = [params.input_path]
-      elif isinstance(params.input_path, (list, tuple)):
-        if any(not isinstance(x, str) for x in params.input_path):
-          raise ValueError(usage % params.input_path)
-        input_path_list = params.input_path
-      else:
-        raise ValueError(usage % params.input_path)
-
-      for input_path in input_path_list:
-        input_patterns = input_path.strip().split(',')
-        for input_pattern in input_patterns:
-          input_pattern = input_pattern.strip()
-          if not input_pattern:
-            continue
-          if '*' in input_pattern or '?' in input_pattern:
-            tmp_matched_files = tf.io.gfile.glob(input_pattern)
-            if not tmp_matched_files:
-              raise ValueError('%s does not match any files.' % input_pattern)
-            self._matched_files.extend(tmp_matched_files)
-          else:
-            self._matched_files.append(input_pattern)
-
-      if not self._matched_files:
-        raise ValueError('%s does not match any files.' % params.input_path)
+      self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.
       if not params.tfds_split:
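For reference, the matching logic (unchanged here, just relocated into `_match_files`) accepts three `input_path` shapes; patterns containing `*` or `?` go through `tf.io.gfile.glob` and must match at least one file. The paths below are invented examples, not part of the commit:

```python
# (1) a single path or glob pattern
input_path = '/data/train-*.tfrecord'
# (2) several paths/patterns in one comma-separated string
input_path = '/data/train-*.tfrecord, /data/extra-00000'
# (3) a list of strings, each itself possibly comma-separated
input_path = ['/data/train-*.tfrecord', '/data/extra-00000,/data/extra-00001']
```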
@@ -148,15 +124,57 @@ class InputReader:
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)
 
+  def _match_files(self, input_path: str) -> List[str]:
+    """Matches files from an input_path."""
+    matched_files = []
+    # Read dataset from files.
+    usage = ('`input_path` should be either (1) a str indicating a file '
+             'path/pattern, or (2) a str indicating multiple file '
+             'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+             '"a,b,c", or (3) a list of str, each of which is a file '
+             'path/pattern or multiple file paths/patterns separated by '
+             'comma, but got: %s')
+    if isinstance(input_path, str):
+      input_path_list = [input_path]
+    elif isinstance(input_path, (list, tuple)):
+      if any(not isinstance(x, str) for x in input_path):
+        raise ValueError(usage % input_path)
+      input_path_list = input_path
+    else:
+      raise ValueError(usage % input_path)
+
+    for input_path in input_path_list:
+      input_patterns = input_path.strip().split(',')
+      for input_pattern in input_patterns:
+        input_pattern = input_pattern.strip()
+        if not input_pattern:
+          continue
+        if '*' in input_pattern or '?' in input_pattern:
+          tmp_matched_files = tf.io.gfile.glob(input_pattern)
+          if not tmp_matched_files:
+            raise ValueError('%s does not match any files.' % input_pattern)
+          matched_files.extend(tmp_matched_files)
+        else:
+          matched_files.append(input_pattern)
+
+    if not matched_files:
+      raise ValueError('%s does not match any files.' % input_path)
+
+    return matched_files
+
   def _shard_files_then_read(
-      self, input_context: Optional[tf.distribute.InputContext] = None):
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Shards the data files and then sends a split to every worker to read."""
-    dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
+    dataset = tf.data.Dataset.from_tensor_slices(matched_files)
 
     # Shuffle and repeat at file level.
     if self._is_training:
       dataset = dataset.shuffle(
-          len(self._matched_files),
+          len(matched_files),
          seed=self._seed,
           reshuffle_each_iteration=True)
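The shard-then-read path builds a dataset of file names first, so shuffling (and, in lines elided from this hunk, `Dataset.shard`) happens at file granularity before any record is opened. A standalone sketch of that idea, with invented file names and stand-ins for the `InputContext` fields:

```python
import tensorflow as tf

files = ['/data/train-00000', '/data/train-00001', '/data/train-00002']
ds = tf.data.Dataset.from_tensor_slices(files)
# Shuffle at file level so every epoch visits files in a new order.
ds = ds.shuffle(len(files), seed=42, reshuffle_each_iteration=True)
# Each input pipeline then keeps only its own slice of the file list.
num_pipelines, pipeline_id = 2, 0  # would come from tf.distribute.InputContext
ds = ds.shard(num_pipelines, pipeline_id)
```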
@@ -171,7 +189,7 @@ class InputReader:
       dataset = dataset.repeat()
 
     dataset = dataset.interleave(
-        map_func=self._dataset_fn,
+        map_func=dataset_fn,
         cycle_length=self._cycle_length,
         block_length=self._block_length,
         num_parallel_calls=(self._cycle_length if self._cycle_length else
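The interleave then opens the sharded file names with the injected `dataset_fn`, pulling `block_length` records at a time from up to `cycle_length` files. A small numeric sketch of those two knobs, faking each "file" as a dataset that repeats its index (values invented):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3).interleave(
    map_func=lambda i: tf.data.Dataset.from_tensors(i).repeat(6),
    cycle_length=2,    # read from two "files" at once
    block_length=3,    # take three records before switching files
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Yields: 0 0 0 1 1 1 0 0 0 1 1 1 2 2 2 2 2 2
```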
@@ -180,9 +198,13 @@ class InputReader:
 
     return dataset
 
   def _read_files_then_shard(
-      self, input_context: Optional[tf.distribute.InputContext] = None):
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Sends all data files to every worker and then shards by data."""
-    dataset = self._dataset_fn(self._matched_files)
+    dataset = dataset_fn(matched_files)
 
     # When `input_file` is a path to a single file or the number of files is
     # less than the number of input pipelines, disable auto sharding
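The comment cut off above refers to turning off tf.data's automatic file-based sharding so every worker reads all files and sharding happens at the element level instead. A hedged sketch of how that is typically expressed (the exact option lines are elided from this hunk; the `range(8)` dataset is a placeholder):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)  # placeholder for dataset_fn(matched_files)
options = tf.data.Options()
# Stop the runtime from splitting by file; with a single file (or fewer
# files than pipelines) each worker must read everything and shard by data.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
```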
@@ -238,26 +260,35 @@ class InputReader:
       raise ValueError('tfds_info is not available, because the dataset '
                        'is not loaded from tfds.')
 
-  def read(
+  def _read_decode_and_parse_dataset(
       self,
-      input_context: Optional[tf.distribute.InputContext] = None
-  ) -> tf.data.Dataset:
-    """Generates a tf.data.Dataset object."""
-    if self._tfds_builder:
+      matched_files: List[str],
+      dataset_fn,
+      batch_size: int,
+      input_context: Optional[tf.distribute.InputContext] = None,
+      tfds_builder: bool = False) -> tf.data.Dataset:
+    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+    if tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif len(self._matched_files) > 1:
-      if input_context and (len(self._matched_files) <
+    elif len(matched_files) > 1:
+      if input_context and (len(matched_files) <
                             input_context.num_input_pipelines):
         logging.warn(
             'The number of files %d is less than the number of input pipelines '
             '%d. We will send all input files to every worker. '
             'Please consider sharding your data into more files.',
-            len(self._matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(input_context)
+            len(matched_files), input_context.num_input_pipelines)
+        dataset = self._read_files_then_shard(matched_files,
+                                              dataset_fn,
+                                              input_context)
       else:
-        dataset = self._shard_files_then_read(input_context)
-    elif len(self._matched_files) == 1:
-      dataset = self._read_files_then_shard(input_context)
+        dataset = self._shard_files_then_read(matched_files,
+                                              dataset_fn,
+                                              input_context)
+    elif len(matched_files) == 1:
+      dataset = self._read_files_then_shard(matched_files,
+                                            dataset_fn,
+                                            input_context)
     else:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
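Summarized, the new method routes by source: TFDS builder first, shard-by-file when several files exist (falling back to read-then-shard with a warning when input pipelines outnumber files), read-then-shard for a single file, and an error otherwise. A compact restatement of that decision tree; the function and its string labels are hypothetical, not part of the commit:

```python
def choose_read_strategy(tfds_builder: bool, num_files: int,
                         num_pipelines: int) -> str:
  """Mirrors the branch order of _read_decode_and_parse_dataset."""
  if tfds_builder:
    return 'read_tfds'
  if num_files > 1:
    if num_pipelines and num_files < num_pipelines:
      return 'read_files_then_shard'  # warned: too few files to shard by file
    return 'shard_files_then_read'
  if num_files == 1:
    return 'read_files_then_shard'
  raise ValueError('no tfds builder and no matched files')
```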
@@ -268,25 +299,28 @@ class InputReader:
     if self._is_training:
       dataset = dataset.shuffle(self._shuffle_buffer_size)
 
-    def maybe_map_fn(dataset, fn):
-      return dataset if fn is None else dataset.map(
-          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-    dataset = maybe_map_fn(dataset, self._decoder_fn)
+    dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
-    dataset = maybe_map_fn(dataset, self._parser_fn)
+    dataset = _maybe_map_fn(dataset, self._parser_fn)
 
     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
     else:
       per_replica_batch_size = input_context.get_per_replica_batch_size(
-          self._global_batch_size) if input_context else self._global_batch_size
+          batch_size) if input_context else batch_size
       dataset = dataset.batch(
-          per_replica_batch_size, drop_remainder=self._drop_remainder)
+          per_replica_batch_size, drop_remainder=self._drop_remainder
+      )
 
-    dataset = maybe_map_fn(dataset, self._postprocess_fn)
+    return dataset
 
+  def _maybe_apply_data_service(
+      self,
+      dataset: tf.data.Dataset,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Potentially distributes a dataset."""
     if self._enable_tf_data_service and input_context:
       if self._enable_round_robin_tf_data_service:
         replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
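The data-service logic that moves into `_maybe_apply_data_service` ultimately hands the dataset to the tf.data service with the kwargs visible in the context below. A minimal sketch of that call; the dispatcher address and job name are made up, and running it requires an actual tf.data service dispatcher:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode='parallel_epochs',
        service='grpc://dispatcher.example:5000',  # hypothetical address
        job_name='shared_training_job'))           # hypothetical job name
```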
@@ -316,6 +350,20 @@ class InputReader:
           processing_mode='parallel_epochs',
           service=self._tf_data_service_address,
           job_name=self._tf_data_service_job_name))
+    return dataset
 
+  def read(
+      self,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Generates a tf.data.Dataset object."""
+    dataset = self._read_decode_and_parse_dataset(self._matched_files,
+                                                  self._dataset_fn,
+                                                  self._global_batch_size,
+                                                  input_context,
+                                                  self._tfds_builder)
+    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
+    dataset = self._maybe_apply_data_service(dataset, input_context)
+
     if self._deterministic is not None:
       options = tf.data.Options()
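After the refactor, `read` is a thin composition: read/decode/parse, optional postprocess, optional data service. Per-worker pipelines are usually built by passing `read` to a distribution strategy, roughly as sketched here; `reader` stands for an already-constructed `InputReader`, and older TF releases spell the method `experimental_distribute_datasets_from_function`:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# Each worker receives a tf.distribute.InputContext and builds its own
# shard of the pipeline via InputReader.read.
dist_dataset = strategy.distribute_datasets_from_function(
    lambda input_context: reader.read(input_context))
```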