Commit fe2a7b30 authored by Chen Chen, committed by A. Unique TensorFlower

Support DataConfig.input_path to be a list of files, and remove the `shards` argument from InputReader()

PiperOrigin-RevId: 331902956
parent 92b86f9e
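
For illustration, here is how the change is meant to be used from a training job. This is a minimal sketch, not part of the commit: the import paths (official.core.*), the file names, and the final read() call are assumptions about the surrounding Model Garden API.

from official.core import config_definitions as cfg
from official.core import input_reader

# `input_path` may now be a list of file paths/patterns; previously the files
# had to be packed into one comma-separated string or passed via `shards=`.
data_config = cfg.DataConfig(
    input_path=['/data/train-00000-of-00010',   # literal path, used as-is
                '/data/extra/*.tfrecord'],      # glob, expanded at init time
    global_batch_size=64,
    is_training=True)

# The `shards` argument is gone; the reader is built from the config alone.
reader = input_reader.InputReader(params=data_config)
dataset = reader.read()  # assumed entry point returning a tf.data.Dataset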
@@ -16,7 +16,7 @@
 """A common dataset reader."""
 import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Optional
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -33,7 +33,6 @@ class InputReader:
   def __init__(self,
                params: cfg.DataConfig,
-               shards: Optional[List[str]] = None,
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
@@ -45,8 +44,6 @@ class InputReader:
     Args:
       params: A config_definitions.DataConfig object.
-      shards: A list of files to be read. If given, read from these files.
-        Otherwise, read from params.input_path.
       dataset_fn: A `tf.data.Dataset` that consumes the input files. For
         example, it can be `tf.data.TFRecordDataset`.
       decoder_fn: An optional `callable` that takes the serialized data string
@@ -56,36 +53,54 @@ class InputReader:
         model. It will be executed after decoder_fn.
       transform_and_batch_fn: An optional `callable` that takes a
         `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
-        input, and returns a `tf.data.Dataset` object. It will be
-        executed after `parser_fn` to transform and batch the dataset; if None,
-        after `parser_fn` is executed, the dataset will be batched into
-        per-replica batch size.
+        input, and returns a `tf.data.Dataset` object. It will be executed after
+        `parser_fn` to transform and batch the dataset; if None, after
+        `parser_fn` is executed, the dataset will be batched into per-replica
+        batch size.
       postprocess_fn: A optional `callable` that processes batched tensors. It
         will be executed after batching.
     """
     if params.input_path and params.tfds_name:
       raise ValueError('At most one of `input_path` and `tfds_name` can be '
-                       'specified, but got %s and %s.' % (
-                           params.input_path, params.tfds_name))
-    self._shards = shards
+                       'specified, but got %s and %s.' %
+                       (params.input_path, params.tfds_name))
     self._tfds_builder = None
-    if self._shards:
-      self._num_files = len(self._shards)
-    elif not params.tfds_name:
-      self._input_patterns = params.input_path.strip().split(',')
-      self._num_files = 0
-      for input_pattern in self._input_patterns:
-        input_pattern = input_pattern.strip()
-        if not input_pattern:
-          continue
-        matched_files = tf.io.gfile.glob(input_pattern)
-        if not matched_files:
-          raise ValueError('%s does not match any files.' % input_pattern)
-        else:
-          self._num_files += len(matched_files)
-      if self._num_files == 0:
+    self._matched_files = []
+    if params.input_path:
+      # Read dataset from files.
+      usage = ('`input_path` should be either (1) a str indicating a file '
+               'path/pattern, or (2) a str indicating multiple file '
+               'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+               '"a,b,c", or (3) a list of str, each of which is a file '
+               'path/pattern or multiple file paths/patterns separated by '
+               'comma, but got: %s')
+      if isinstance(params.input_path, str):
+        input_path_list = [params.input_path]
+      elif isinstance(params.input_path, (list, tuple)):
+        if any(not isinstance(x, str) for x in params.input_path):
+          raise ValueError(usage % params.input_path)
+        input_path_list = params.input_path
+      else:
+        raise ValueError(usage % params.input_path)
+      for input_path in input_path_list:
+        input_patterns = input_path.strip().split(',')
+        for input_pattern in input_patterns:
+          input_pattern = input_pattern.strip()
+          if not input_pattern:
+            continue
+          if '*' in input_pattern or '?' in input_pattern:
+            tmp_matched_files = tf.io.gfile.glob(input_pattern)
+            if not tmp_matched_files:
+              raise ValueError('%s does not match any files.' % input_pattern)
+            self._matched_files.extend(tmp_matched_files)
+          else:
+            self._matched_files.append(input_pattern)
+      if not self._matched_files:
         raise ValueError('%s does not match any files.' % params.input_path)
     else:
+      # Read dataset from TFDS.
       if not params.tfds_split:
         raise ValueError(
             '`tfds_name` is %s, but `tfds_split` is not specified.' %
@@ -119,23 +134,16 @@
     self._tf_data_service_address = params.tf_data_service_address
     self._tf_data_service_job_name = params.tf_data_service_job_name
 
-  def _read_sharded_files(
-      self,
-      input_context: Optional[tf.distribute.InputContext] = None):
+  def _read_sharded_files(self,
+                          input_context: Optional[
+                              tf.distribute.InputContext] = None):
     """Reads a dataset from sharded files."""
-    # Read from `self._shards` if it is provided.
-    if self._shards:
-      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
-    else:
-      dataset = tf.data.Dataset.list_files(
-          self._input_patterns,
-          seed=self._seed,
-          shuffle=self._is_training)
+    dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
 
     # Shuffle and repeat at file level.
-    if self._shards and self._is_training:
+    if self._is_training:
       dataset = dataset.shuffle(
-          len(self._shards),
+          len(self._matched_files),
           seed=self._seed,
           reshuffle_each_iteration=True)
@@ -157,12 +165,12 @@
           deterministic=self._deterministic)
     return dataset
 
-  def _read_single_file(
-      self,
-      input_context: Optional[tf.distribute.InputContext] = None):
+  def _read_single_file(self,
+                        input_context: Optional[
+                            tf.distribute.InputContext] = None):
     """Reads a dataset from a single file."""
-    # Read from `self._shards` if it is provided.
-    dataset = self._dataset_fn(self._shards or self._input_patterns)
+    dataset = self._dataset_fn(self._matched_files)
 
     # When `input_file` is a path to a single file, disable auto sharding
     # so that same input file is sent to all workers.
@@ -224,11 +232,13 @@
     """Generates a tf.data.Dataset object."""
     if self._tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif self._num_files > 1:
+    elif len(self._matched_files) > 1:
       dataset = self._read_sharded_files(input_context)
-    else:
-      assert self._num_files == 1
+    elif len(self._matched_files) == 1:
       dataset = self._read_single_file(input_context)
+    else:
+      raise ValueError('It is unexpected that `tfds_builder` is None and '
+                       'there is also no `matched_files`.')
 
     if self._cache:
       dataset = dataset.cache()
......
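
The hunk above only globs entries that actually contain a wildcard; plain paths are kept as-is, and each entry of a list may itself be a comma-separated group. A standalone sketch of that matching behavior, useful for checking what a given `input_path` expands to (expand_input_path is a hypothetical helper mirroring the __init__ logic, not part of the library):

from typing import List, Sequence, Union

import tensorflow as tf


def expand_input_path(input_path: Union[Sequence[str], str]) -> List[str]:
  """Expands a str / comma-separated str / list of str into concrete files."""
  if isinstance(input_path, str):
    input_path_list = [input_path]
  else:
    input_path_list = list(input_path)

  matched_files = []
  for path in input_path_list:
    for pattern in path.strip().split(','):
      pattern = pattern.strip()
      if not pattern:
        continue
      # Only glob when the entry looks like a pattern; literal paths are
      # passed through unchecked, so a missing file surfaces later at read time.
      if '*' in pattern or '?' in pattern:
        matched = tf.io.gfile.glob(pattern)
        if not matched:
          raise ValueError('%s does not match any files.' % pattern)
        matched_files.extend(matched)
      else:
        matched_files.append(pattern)
  return matched_files


# Example: mixing a literal file, a glob, and a comma-separated string.
# expand_input_path(['/data/train-00000-of-00010',
#                    '/data/val/*.tfrecord',
#                    'a.tfrecord, b.tfrecord'])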
@@ -15,7 +15,7 @@
 # ==============================================================================
 """Common configuration settings."""
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 import dataclasses
@@ -30,9 +30,12 @@ class DataConfig(base_config.Config):
   """The base configuration for building datasets.
 
   Attributes:
-    input_path: The path to the input. It can be either (1) a file pattern, or
-      (2) multiple file patterns separated by comma. It should not be specified
-      when the following `tfds_name` is specified.
+    input_path: The path to the input. It can be either (1) a str indicating
+      a file path/pattern, or (2) a str indicating multiple file paths/patterns
+      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
+      (3) a list of str, each of which is a file path/pattern or multiple file
+      paths/patterns separated by comma.
+      It should not be specified when the following `tfds_name` is specified.
     tfds_name: The name of the tensorflow dataset (TFDS). It should not be
       specified when the above `input_path` is specified.
     tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -71,7 +74,7 @@ class DataConfig(base_config.Config):
       features. The main use case is to skip the image/video decoding for better
       performance.
   """
-  input_path: str = ""
+  input_path: Union[Sequence[str], str] = ""
   tfds_name: str = ""
   tfds_split: str = ""
   global_batch_size: int = 0
......
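
The widened annotation `Union[Sequence[str], str]` covers all three forms described in the docstring. A quick sketch, under the same assumed import path as the earlier example:

from official.core import config_definitions as cfg

# (1) a single path or pattern
single = cfg.DataConfig(input_path='gs://bucket/train-*.tfrecord')
# (2) multiple paths/patterns in one comma-separated string, spaces optional
comma = cfg.DataConfig(input_path='a.tfrecord, b.tfrecord,c.tfrecord')
# (3) a list of str, where each element may itself be a comma-separated group
listed = cfg.DataConfig(input_path=['train-00000-of-00010', 'extra/*.tfrecord'])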