"...resnet50_tensorflow.git" did not exist on "9cc7eac1eb2b33e79611789827da0b2376ca8954"
Commit 88b2a354 authored by A. Unique TensorFlower

Adds a feature to process a dictionary for `input_paths` in `DataConfig`, to allow combining multiple datasets using a user-defined `combine_fn`.

PiperOrigin-RevId: 381363688
parent 50ebc683
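In practice the new option would be wired up roughly as in the sketch below. This is an illustrative example rather than code from this commit: the import paths, the file patterns, the sampling weights, and the `reader.read()` call are assumptions about the surrounding Model Garden API.

```python
# Hedged sketch of the new feature: a dict-valued `input_path` plus a
# user-defined `combine_fn`. Import paths, file patterns, weights, and
# `reader.read()` are assumptions, not part of this commit.
import tensorflow as tf
from official.core import config_definitions as cfg  # assumed module path
from official.core import input_reader               # assumed module path

params = cfg.DataConfig(
    input_path={'coco': '/data/coco/train*',          # hypothetical patterns
                'objects365': '/data/o365/train*'},
    global_batch_size=64,
    is_training=True)

def combine_fn(datasets):
  # `datasets` maps the `input_path` keys to decoded tf.data.Dataset objects;
  # mix them with fixed sampling weights.
  return tf.data.experimental.sample_from_datasets(
      [datasets['coco'], datasets['objects365']], weights=[0.7, 0.3])

reader = input_reader.InputReader(
    params,
    dataset_fn=tf.data.TFRecordDataset,
    combine_fn=combine_fn)
dataset = reader.read()  # assumed entry point for building the tf.data pipeline
```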
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a str indicating
a file path/pattern, or (2) a str indicating multiple file paths/patterns
separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
(3) a list of str, each of which is a file path/pattern or multiple file
paths/patterns separated by comma.
It should not be specified when the following `tfds_name` is specified.
input_path: The path to the input. It can be either (1) a str indicating a
file path/pattern, or (2) a str indicating multiple file paths/patterns
separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
str, each of which is a file path/pattern or multiple file paths/patterns
separated by comma, or (4) a dictionary of the previous three approaches
for more advanced data mixing using named access. It should not be
specified when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. If `True`, we will cache the
dataset after applying the decode_fn and parse_fn. It can be used to avoid
re-reading from disk, re-decoding and re-parsing the example on the
second epoch, but it requires significant memory overhead.
re-reading from disk, re-decoding and re-parsing the example on the second
epoch, but it requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
tf_data_service_address: The URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". It can be
overridden by `FLAGS.tf_data_service` flag in the binary.
tf_data_service_job_name: The name of the tf.data service job. This
argument makes it possible for multiple datasets to share the same job.
The default behavior is that the dataset creates anonymous, exclusively
owned jobs.
overridden by `FLAGS.tf_data_service` flag in the binary.
tf_data_service_job_name: The name of the tf.data service job. This argument
makes it possible for multiple datasets to share the same job. The default
behavior is that the dataset creates anonymous, exclusively owned jobs.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
"""
input_path: Union[Sequence[str], str] = ""
input_path: Union[Sequence[str], str, base_config.Config] = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
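For reference, the four accepted `input_path` forms described in the updated docstring could look like the following; the concrete paths are made up for illustration.

```python
# Illustrative values for the four documented `input_path` forms; the paths
# are made up.
form_1 = '/data/train-00000-of-00100'                   # (1) single path/pattern
form_2 = '/data/part-a*,/data/part-b*'                  # (2) comma-separated patterns
form_3 = ['/data/part-a*', '/data/part-b*']             # (3) list of paths/patterns
form_4 = {'main': '/data/part-a*',                      # (4) dict of any of the above,
          'aux': '/data/part-b*'}                       #     for named data mixing
```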
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Union, Dict, Sequence
from absl import logging
import tensorflow as tf
@@ -45,6 +45,7 @@ class InputReader:
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
combine_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ class InputReader:
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
combine_fn: An optional `callable` that takes a dictionary of
`tf.data.Dataset` objects as input and outputs a combined dataset. It
will be executed after the decoder_fn and before the sample_fn.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
@@ -78,10 +82,23 @@ class InputReader:
raise ValueError('At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.' %
(params.input_path, params.tfds_name))
if isinstance(params.input_path,
cfg.base_config.Config) and combine_fn is None:
raise ValueError(
'A `combine_fn` is required if the `input_path` is a dictionary.')
self._tfds_builder = None
self._matched_files = []
self._matched_files = None
if params.input_path:
self._matched_files = self._match_files(params.input_path)
# we want to combine / mix datasets
if isinstance(params.input_path, cfg.base_config.Config):
self._matched_files = {}
for k, v in params.input_path.as_dict().items():
self._matched_files[k] = self._match_files(v)
# single dataset
else:
self._matched_files = self._match_files(params.input_path)
else:
# Read dataset from TFDS.
if not params.tfds_split:
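The guard above means a dict-valued `input_path` without a `combine_fn` fails at construction time. A hedged sketch, assuming the dict assigned to `input_path` is wrapped into a `base_config.Config` by the config framework, and using assumed module paths and hypothetical file patterns:

```python
import tensorflow as tf
from official.core import config_definitions as cfg  # assumed module path
from official.core import input_reader               # assumed module path

params = cfg.DataConfig(
    input_path={'a': '/tmp/a*', 'b': '/tmp/b*'},  # hypothetical patterns
    global_batch_size=8)

try:
  # No `combine_fn` supplied, so the constructor should reject the dict input.
  input_reader.InputReader(params, dataset_fn=tf.data.TFRecordDataset)
except ValueError as err:
  print(err)  # 'A `combine_fn` is required if the `input_path` is a dictionary.'
```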
@@ -106,6 +123,7 @@ class InputReader:
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._combine_fn = combine_fn
self._sample_fn = sample_fn
self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn
@@ -131,7 +149,7 @@ class InputReader:
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
def _match_files(self, input_path: str) -> List[str]:
def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
@@ -195,8 +213,8 @@ class InputReader:
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
@@ -231,8 +249,8 @@ class InputReader:
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
@@ -281,42 +299,53 @@ class InputReader:
def _read_decode_and_parse_dataset(
self,
matched_files: List[str],
matched_files: Union[Dict[str, List[str]], List[str]],
dataset_fn,
batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: bool = False) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after reading, decoding, and parsing."""
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
if len(files) > 1:
if input_context and (len(files) < input_context.num_input_pipelines):
logging.warn(
'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.', len(files),
input_context.num_input_pipelines)
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
return self._shard_files_then_read(files, dataset_fn, input_context)
elif len(files) == 1:
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `files`.')
def _shuffle_and_decode(ds):
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Decode
ds = _maybe_map_fn(ds, self._decoder_fn)
return ds
if tfds_builder:
dataset = self._read_tfds(input_context)
elif len(matched_files) > 1:
if input_context and (len(matched_files) <
input_context.num_input_pipelines):
logging.warn(
'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.',
len(matched_files), input_context.num_input_pipelines)
dataset = self._read_files_then_shard(matched_files,
dataset_fn,
input_context)
else:
dataset = self._shard_files_then_read(matched_files,
dataset_fn,
input_context)
elif len(matched_files) == 1:
dataset = self._read_files_then_shard(matched_files,
dataset_fn,
input_context)
dataset = _shuffle_and_decode(dataset)
elif isinstance(matched_files, (list, tuple)):
dataset = _files_to_dataset(matched_files)
dataset = _shuffle_and_decode(dataset)
elif isinstance(matched_files, dict):
datasets = {}
for k, fs in matched_files.items():
datasets[k] = _files_to_dataset(fs)
datasets[k] = _shuffle_and_decode(datasets[k])
dataset = self._combine_fn(datasets)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `matched_files`.')
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
raise ValueError('`matched_files` should be a list or dict.')
dataset = _maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn)
dataset = _maybe_map_fn(dataset, self._parser_fn)
@@ -333,8 +362,7 @@ class InputReader:
per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder
)
per_replica_batch_size, drop_remainder=self._drop_remainder)
return dataset
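Since `combine_fn` receives the full dict of shuffled and decoded datasets and only has to return a single dataset, mixing by sampling is not the only option; a zip-style combiner is another possibility. The sketch below uses illustrative names and assumes each source yields compatible elements.

```python
import tensorflow as tf

def paired_combine_fn(datasets):
  # `datasets` maps each `input_path` key to its decoded tf.data.Dataset.
  # Zipping yields dict elements such as {'main': ..., 'aux': ...}, pairing one
  # example from every source per step; the combined dataset then flows into
  # `sample_fn` (if any), `parser_fn`, and batching as a single pipeline.
  return tf.data.Dataset.zip(datasets)
```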