Commit 5a3af75c authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Support to read a dataset from TFDS.

PiperOrigin-RevId: 315774221
parent 8a1dbbad
......@@ -18,6 +18,7 @@
from typing import Any, Callable, List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import config_definitions as cfg
......@@ -53,11 +54,15 @@ class InputReader:
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
"""
# TODO(chendouble): Support TFDS as input_path.
if params.input_path and params.tfds_name:
raise ValueError('At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.' % (
params.input_path, params.tfds_name))
self._shards = shards
self._tfds_builder = None
if self._shards:
self._num_files = len(self._shards)
else:
elif not params.tfds_name:
self._input_patterns = params.input_path.strip().split(',')
self._num_files = 0
for input_pattern in self._input_patterns:
......@@ -71,6 +76,13 @@ class InputReader:
self._num_files += len(matched_files)
if self._num_files == 0:
raise ValueError('%s does not match any files.' % params.input_path)
else:
if not params.tfds_split:
raise ValueError(
'`tfds_name` is %s, but `tfds_split` is not specified.' %
params.tfds_name)
self._tfds_builder = tfds.builder(
params.tfds_name, data_dir=params.tfds_data_dir)
self._global_batch_size = params.global_batch_size
self._is_training = params.is_training
......@@ -78,8 +90,13 @@ class InputReader:
self._shuffle_buffer_size = params.shuffle_buffer_size
self._cache = params.cache
self._cycle_length = params.cycle_length
self._block_length = params.block_length
self._sharding = params.sharding
self._examples_consume = params.examples_consume
self._tfds_split = params.tfds_split
self._tfds_download = params.tfds_download
self._tfds_as_supervised = params.tfds_as_supervised
self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
......@@ -107,6 +124,7 @@ class InputReader:
dataset = dataset.interleave(
map_func=self._dataset_fn,
cycle_length=self._cycle_length,
block_length=self._block_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
......@@ -131,12 +149,47 @@ class InputReader:
dataset = dataset.repeat()
return dataset
def _read_tfds(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
if self._tfds_download:
self._tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=self._cycle_length,
interleave_block_length=self._block_length,
input_context=input_context)
decoders = {}
if self._tfds_skip_decoding_feature:
for skip_feature in self._tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = self._tfds_builder.as_dataset(
split=self._tfds_split,
shuffle_files=self._is_training,
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
return dataset
@property
def tfds_info(self) -> tfds.core.DatasetInfo:
"""Returns TFDS dataset info, if available."""
if self._tfds_builder:
return self._tfds_builder.info
else:
raise ValueError('tfds_info is not available, because the dataset '
'is not loaded from tfds.')
def read(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
if self._num_files > 1:
if self._tfds_builder:
dataset = self._read_tfds(input_context)
elif self._num_files > 1:
dataset = self._read_sharded_files(input_context)
else:
assert self._num_files == 1
......
......@@ -31,7 +31,12 @@ class DataConfig(base_config.Config):
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma.
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
......@@ -41,21 +46,40 @@ class DataConfig(base_config.Config):
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.error.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
input_path: str = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
is_training: bool = None
drop_remainder: bool = True
shuffle_buffer_size: int = 100
cache: bool = False
cycle_length: int = 8
block_length: int = 1
sharding: bool = True
examples_consume: int = -1
tfds_data_dir: str = ""
tfds_download: bool = False
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment