Commit fe2a7b30 authored by Chen Chen, committed by A. Unique TensorFlower

Support DataConfig.input_path to be a list of files, and remove the `shards` argument from InputReader()

PiperOrigin-RevId: 331902956
parent 92b86f9e
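
In practice, the change means a file list that previously had to be handed to `InputReader` through the now-removed `shards` argument can be carried by `DataConfig.input_path` itself. Below is a minimal usage sketch; the import paths, the `is_training` field, and the `read()` method are assumptions based on typical Model Garden usage and are not part of this diff.

```python
import tensorflow as tf

# Hypothetical module paths; not shown in this diff.
from official.core import config_definitions as cfg
from official.core import input_reader

params = cfg.DataConfig(
    # New: `input_path` also accepts a list of file paths/patterns.
    input_path=['/tmp/data/train-00000-of-00002',
                '/tmp/data/train-00001-of-00002'],
    global_batch_size=64,
    is_training=True)  # `is_training` is assumed, not part of this diff.

# Before this change, an explicit file list went through the removed argument:
#   reader = input_reader.InputReader(params, shards=file_list)
reader = input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset)
dataset = reader.read()  # assumed method name; returns a tf.data.Dataset
```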
@@ -16,7 +16,7 @@
"""A common dataset reader."""
import random
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
@@ -33,7 +33,6 @@ class InputReader:
def __init__(self,
params: cfg.DataConfig,
shards: Optional[List[str]] = None,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
@@ -45,8 +44,6 @@
Args:
params: A config_definitions.DataConfig object.
shards: A list of files to be read. If given, read from these files.
Otherwise, read from params.input_path.
dataset_fn: A `tf.data.Dataset` that consumes the input files. For
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
@@ -56,36 +53,54 @@
model. It will be executed after decoder_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be
executed after `parser_fn` to transform and batch the dataset; if None,
after `parser_fn` is executed, the dataset will be batched into
per-replica batch size.
input, and returns a `tf.data.Dataset` object. It will be executed after
`parser_fn` to transform and batch the dataset; if None, after
`parser_fn` is executed, the dataset will be batched into per-replica
batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
"""
if params.input_path and params.tfds_name:
raise ValueError('At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.' % (
params.input_path, params.tfds_name))
self._shards = shards
'specified, but got %s and %s.' %
(params.input_path, params.tfds_name))
self._tfds_builder = None
if self._shards:
self._num_files = len(self._shards)
elif not params.tfds_name:
self._input_patterns = params.input_path.strip().split(',')
self._num_files = 0
for input_pattern in self._input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
matched_files = tf.io.gfile.glob(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
else:
self._num_files += len(matched_files)
if self._num_files == 0:
self._matched_files = []
if params.input_path:
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(params.input_path, str):
input_path_list = [params.input_path]
elif isinstance(params.input_path, (list, tuple)):
if any(not isinstance(x, str) for x in params.input_path):
raise ValueError(usage % params.input_path)
input_path_list = params.input_path
else:
raise ValueError(usage % params.input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
self._matched_files.extend(tmp_matched_files)
else:
self._matched_files.append(input_pattern)
if not self._matched_files:
raise ValueError('%s does not match any files.' % params.input_path)
else:
# Read dataset from TFDS.
if not params.tfds_split:
raise ValueError(
'`tfds_name` is %s, but `tfds_split` is not specified.' %
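
The constructor docstring in the hunk above fixes the order in which the optional callables run: `decoder_fn` on each serialized record, then `parser_fn`, then either `transform_and_batch_fn` or the default per-replica batching, and finally `postprocess_fn` on the batched tensors. The sketch below only illustrates what such callables could look like for TFRecords of `tf.train.Example` protos; the feature names, shapes, and label depth are invented for the example.

```python
import tensorflow as tf


def decoder_fn(serialized_example):
  # Runs first, once per serialized record.
  return tf.io.parse_single_example(
      serialized_example,
      features={
          'image/encoded': tf.io.FixedLenFeature([], tf.string),
          'label': tf.io.FixedLenFeature([], tf.int64),
      })


def parser_fn(decoded):
  # Runs after decoder_fn; turns the decoded dict into model inputs.
  image = tf.io.decode_jpeg(decoded['image/encoded'], channels=3)
  image = tf.image.resize(image, [224, 224]) / 255.0
  return image, decoded['label']


def postprocess_fn(images, labels):
  # Runs last, on tensors that have already been batched.
  return images, tf.one_hot(labels, depth=10)

# Wiring (hypothetical): InputReader(params, decoder_fn=decoder_fn,
#                                    parser_fn=parser_fn,
#                                    postprocess_fn=postprocess_fn)
```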
@@ -119,23 +134,16 @@
self._tf_data_service_address = params.tf_data_service_address
self._tf_data_service_job_name = params.tf_data_service_job_name
def _read_sharded_files(
self,
input_context: Optional[tf.distribute.InputContext] = None):
def _read_sharded_files(self,
input_context: Optional[
tf.distribute.InputContext] = None):
"""Reads a dataset from sharded files."""
# Read from `self._shards` if it is provided.
if self._shards:
dataset = tf.data.Dataset.from_tensor_slices(self._shards)
else:
dataset = tf.data.Dataset.list_files(
self._input_patterns,
seed=self._seed,
shuffle=self._is_training)
dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
# Shuffle and repeat at file level.
if self._shards and self._is_training:
if self._is_training:
dataset = dataset.shuffle(
len(self._shards),
len(self._matched_files),
seed=self._seed,
reshuffle_each_iteration=True)
@@ -157,12 +165,12 @@
deterministic=self._deterministic)
return dataset
def _read_single_file(
self,
input_context: Optional[tf.distribute.InputContext] = None):
def _read_single_file(self,
input_context: Optional[
tf.distribute.InputContext] = None):
"""Reads a dataset from a single file."""
# Read from `self._shards` if it is provided.
dataset = self._dataset_fn(self._shards or self._input_patterns)
dataset = self._dataset_fn(self._matched_files)
# When `input_file` is a path to a single file, disable auto sharding
# so that same input file is sent to all workers.
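
The comment closing this hunk explains that auto sharding is disabled when a single input file must be seen by every worker. The mechanism itself is outside the excerpt; in tf.data it is typically expressed through dataset options, roughly as in this assumed sketch.

```python
import tensorflow as tf


def _disable_auto_shard(dataset: tf.data.Dataset) -> tf.data.Dataset:
  """Sketch: turn off file-level auto sharding so all workers read the file."""
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  return dataset.with_options(options)
```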
@@ -224,11 +232,13 @@
"""Generates a tf.data.Dataset object."""
if self._tfds_builder:
dataset = self._read_tfds(input_context)
elif self._num_files > 1:
elif len(self._matched_files) > 1:
dataset = self._read_sharded_files(input_context)
else:
assert self._num_files == 1
elif len(self._matched_files) == 1:
dataset = self._read_single_file(input_context)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `matched_files`.')
if self._cache:
dataset = dataset.cache()
@@ -15,7 +15,7 @@
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
from typing import Optional, Sequence, Union
import dataclasses
@@ -30,9 +30,12 @@ class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
input_path: The path to the input. It can be either (1) a str indicating
a file path/pattern, or (2) a str indicating multiple file paths/patterns
separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
(3) a list of str, each of which is a file path/pattern or multiple file
paths/patterns separated by comma.
It should not be specified when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -71,7 +74,7 @@ class DataConfig(base_config.Config):
features. The main use case is to skip the image/video decoding for better
performance.
"""
input_path: str = ""
input_path: Union[Sequence[str], str] = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
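
The three accepted forms listed in the new `input_path` docstring name files in equivalent ways; a small illustration (all paths and the TFDS name are invented):

```python
# (1) A single file path or glob pattern.
input_path = '/tmp/data/train-*.tfrecord'

# (2) Several paths/patterns in one comma-separated string; spaces after the
#     commas are allowed ("a, b, c") or omitted ("a,b,c").
input_path = '/tmp/data/train-*.tfrecord, /tmp/data/extra-00000-of-00001'

# (3) A list (or tuple) of strings, each of which may itself contain a
#     comma-separated group of paths/patterns.
input_path = ['/tmp/data/train-*.tfrecord', '/tmp/data/extra-00000-of-00001']

# `input_path` must stay empty when reading from TFDS instead:
# tfds_name = 'mnist'   # invented example
# tfds_split = 'train'
```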