Commit 6264aa72 authored by A. Unique TensorFlower

Adds a feature to process a dictionary for `input_path` in `DataConfig`, to allow combining multiple datasets using a user-defined `combine_fn`.

PiperOrigin-RevId: 381363688
parent 191d9624
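To make the new behavior concrete, here is a minimal sketch of a config that names two datasets. The dataset names and file patterns are made up, and passing a plain dict for `input_path` assumes the hyperparams framework converts it into a `base_config.Config`; if it does not, the dictionary would need to be wrapped explicitly.

# Illustrative only: 'books'/'wiki' and the file patterns are hypothetical.
from official.core import config_definitions as cfg

data_config = cfg.DataConfig(
    # Assumes the config framework wraps this dict into a base_config.Config.
    input_path={
        'books': '/data/books*.tfrecord',
        'wiki': '/data/wiki*.tfrecord',
    },
    global_batch_size=64,
    is_training=True)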
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    input_path: The path to the input. It can be either (1) a str indicating a
      file path/pattern, or (2) a str indicating multiple file paths/patterns
      separated by comma (e.g. "a, b, c" or no spaces "a,b,c"), or (3) a list
      of str, each of which is a file path/pattern or multiple file
      paths/patterns separated by comma, or (4) a dictionary of the previous
      three approaches for more advanced data mixing using named access. It
      should not be specified when the following `tfds_name` is specified.
    tfds_name: The name of the tensorflow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
    shuffle_buffer_size: The buffer size used for shuffling training data.
    cache: Whether to cache dataset examples. If `True`, we will cache the
      dataset after applying the decode_fn and parse_fn. It can be used to
      avoid re-reading from disk, re-decoding and re-parsing the example on the
      second epoch, but it requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
    tf_data_service_address: The URI of a tf.data service to offload
      preprocessing onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
      overridden by the `FLAGS.tf_data_service` flag in the binary.
    tf_data_service_job_name: The name of the tf.data service job. This
      argument makes it possible for multiple datasets to share the same job.
      The default behavior is that the dataset creates anonymous, exclusively
      owned jobs.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
      performance.
    seed: An optional seed to use for deterministic shuffling/preprocessing.
  """
  input_path: Union[Sequence[str], str, base_config.Config] = ""
  tfds_name: str = ""
  tfds_split: str = ""
  global_batch_size: int = 0
...
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
from typing import Any, Callable, List, Optional, Union, Dict, Sequence

from absl import logging
import tensorflow as tf
@@ -45,6 +45,7 @@ class InputReader:
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               combine_fn: Optional[Callable[..., Any]] = None,
               sample_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ class InputReader:
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes it into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionary of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object
        as input and outputs the transformed dataset. It performs sampling on
        the decoded raw tensors dict before the parser_fn.
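As a concrete illustration of this contract, the sketch below mixes the named datasets by sampling. The dataset keys and weights carry over from the config sketch above and are assumptions, as is the helper name.

from typing import Dict
import tensorflow as tf

def sample_combine_fn(datasets: Dict[str, tf.data.Dataset]) -> tf.data.Dataset:
  # `datasets` is keyed by the names used in `input_path` ('books'/'wiki' here,
  # both hypothetical). The sampling weights are illustrative.
  return tf.data.experimental.sample_from_datasets(
      [datasets['books'], datasets['wiki']], weights=[0.7, 0.3])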
@@ -78,10 +82,23 @@ class InputReader:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    if isinstance(params.input_path,
                  cfg.base_config.Config) and combine_fn is None:
      raise ValueError(
          'A `combine_fn` is required if the `input_path` is a dictionary.')

    self._tfds_builder = None
    self._matched_files = None
    if params.input_path:
      if isinstance(params.input_path, cfg.base_config.Config):
        # We want to combine / mix multiple datasets keyed by name.
        self._matched_files = {}
        for k, v in params.input_path.as_dict().items():
          self._matched_files[k] = self._match_files(v)
      else:
        # A single dataset.
        self._matched_files = self._match_files(params.input_path)
    else:
      # Read dataset from TFDS.
      if not params.tfds_split:
@@ -106,6 +123,7 @@ class InputReader:
    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
@@ -131,7 +149,7 @@ class InputReader:
    self._enable_round_robin_tf_data_service = params.get(
        'enable_round_robin_tf_data_service', False)

  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
    """Matches files from an input_path."""
    matched_files = []
    # Read dataset from files.
@@ -195,8 +213,8 @@ class InputReader:
    # Do not enable sharding if tf.data service is enabled, as sharding will be
    # handled inside tf.data service.
    if self._sharding and input_context and (input_context.num_input_pipelines >
                                             1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
@@ -231,8 +249,8 @@ class InputReader:
    dataset = dataset.with_options(options)
    # Do not enable sharding if tf.data service is enabled, as sharding will be
    # handled inside tf.data service.
    if self._sharding and input_context and (input_context.num_input_pipelines >
                                             1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
@@ -281,42 +299,53 @@ class InputReader:
  def _read_decode_and_parse_dataset(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: bool = False) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""

    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      if len(files) > 1:
        if input_context and (len(files) < input_context.num_input_pipelines):
          logging.warn(
              'The number of files %d is less than the number of input pipelines '
              '%d. We will send all input files to every worker. '
              'Please consider sharding your data into more files.', len(files),
              input_context.num_input_pipelines)
          return self._read_files_then_shard(files, dataset_fn, input_context)
        else:
          return self._shard_files_then_read(files, dataset_fn, input_context)
      elif len(files) == 1:
        return self._read_files_then_shard(files, dataset_fn, input_context)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')

    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
      # Decode the raw examples.
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds

    if tfds_builder:
      dataset = self._read_tfds(input_context)
      dataset = _shuffle_and_decode(dataset)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
      dataset = _shuffle_and_decode(dataset)
    elif isinstance(matched_files, dict):
      datasets = {}
      for k, fs in matched_files.items():
        datasets[k] = _files_to_dataset(fs)
        datasets[k] = _shuffle_and_decode(datasets[k])
      dataset = self._combine_fn(datasets)
    else:
      raise ValueError('`matched_files` should be a list or dict.')

    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)
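Putting the pieces together, a hedged wiring sketch: each named entry in `input_path` is matched, read, shuffled, and decoded independently, the resulting dict of datasets is handed to `combine_fn`, and `sample_fn`/`parser_fn` then operate on the combined dataset. `my_decoder_fn` and `my_parser_fn` are placeholders, `data_config` and `sample_combine_fn` come from the earlier sketches, and calling `read()` assumes the reader's existing public entry point.

from official.core import input_reader

# dataset_fn defaults to tf.data.TFRecordDataset, matching the file patterns above.
reader = input_reader.InputReader(
    params=data_config,
    decoder_fn=my_decoder_fn,        # decodes each serialized example
    combine_fn=sample_combine_fn,    # mixes the {'books', 'wiki'} datasets
    parser_fn=my_parser_fn)          # parses the combined, decoded examples
dataset = reader.read()  # Assumed existing InputReader.read() entry point.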
@@ -333,8 +362,7 @@ class InputReader:
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        batch_size) if input_context else batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._drop_remainder)
    return dataset
...