Commit 5a3af75c authored by Chen Chen, committed by A. Unique TensorFlower

Support reading a dataset from TFDS.

PiperOrigin-RevId: 315774221
parent 8a1dbbad
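
This change lets InputReader pull examples from TFDS (tensorflow_datasets) when `tfds_name` is set in the data config, as an alternative to globbing `input_path` file patterns. The TFDS calls it wires together are sketched below as a standalone illustration; the dataset name 'mnist' and the data directory are placeholders, not part of this commit:

    import tensorflow_datasets as tfds

    builder = tfds.builder('mnist', data_dir='/tmp/tfds')  # placeholder name/dir
    builder.download_and_prepare()  # corresponds to tfds_download=True

    # ReadConfig carries the interleave knobs and, under tf.distribute, the
    # per-replica input context, mirroring what the new _read_tfds() passes in.
    read_config = tfds.ReadConfig(interleave_cycle_length=8,
                                  interleave_block_length=1)
    dataset = builder.as_dataset(
        split='train',            # corresponds to tfds_split
        shuffle_files=True,       # tied to is_training in this change
        as_supervised=False,      # dict of features; True yields (input, label)
        read_config=read_config)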
@@ -18,6 +18,7 @@
 from typing import Any, Callable, List, Optional
 import tensorflow as tf
+import tensorflow_datasets as tfds
 from official.modeling.hyperparams import config_definitions as cfg
@@ -53,11 +54,15 @@ class InputReader:
       postprocess_fn: A optional `callable` that processes batched tensors. It
         will be executed after batching.
     """
-    # TODO(chendouble): Support TFDS as input_path.
+    if params.input_path and params.tfds_name:
+      raise ValueError('At most one of `input_path` and `tfds_name` can be '
+                       'specified, but got %s and %s.' % (
+                           params.input_path, params.tfds_name))
     self._shards = shards
+    self._tfds_builder = None
     if self._shards:
       self._num_files = len(self._shards)
-    else:
+    elif not params.tfds_name:
       self._input_patterns = params.input_path.strip().split(',')
       self._num_files = 0
       for input_pattern in self._input_patterns:
@@ -71,6 +76,13 @@ class InputReader:
         self._num_files += len(matched_files)
       if self._num_files == 0:
         raise ValueError('%s does not match any files.' % params.input_path)
+    else:
+      if not params.tfds_split:
+        raise ValueError(
+            '`tfds_name` is %s, but `tfds_split` is not specified.' %
+            params.tfds_name)
+      self._tfds_builder = tfds.builder(
+          params.tfds_name, data_dir=params.tfds_data_dir)
 
     self._global_batch_size = params.global_batch_size
     self._is_training = params.is_training
@@ -78,8 +90,13 @@ class InputReader:
     self._shuffle_buffer_size = params.shuffle_buffer_size
     self._cache = params.cache
     self._cycle_length = params.cycle_length
+    self._block_length = params.block_length
     self._sharding = params.sharding
     self._examples_consume = params.examples_consume
+    self._tfds_split = params.tfds_split
+    self._tfds_download = params.tfds_download
+    self._tfds_as_supervised = params.tfds_as_supervised
+    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
     self._dataset_fn = dataset_fn
     self._decoder_fn = decoder_fn
@@ -107,6 +124,7 @@ class InputReader:
     dataset = dataset.interleave(
         map_func=self._dataset_fn,
         cycle_length=self._cycle_length,
+        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
     return dataset
@@ -131,12 +149,47 @@ class InputReader:
       dataset = dataset.repeat()
     return dataset
 
+  def _read_tfds(
+      self,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Reads a dataset from tfds."""
+    if self._tfds_download:
+      self._tfds_builder.download_and_prepare()
+
+    read_config = tfds.ReadConfig(
+        interleave_cycle_length=self._cycle_length,
+        interleave_block_length=self._block_length,
+        input_context=input_context)
+    decoders = {}
+    if self._tfds_skip_decoding_feature:
+      for skip_feature in self._tfds_skip_decoding_feature.split(','):
+        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
+    dataset = self._tfds_builder.as_dataset(
+        split=self._tfds_split,
+        shuffle_files=self._is_training,
+        as_supervised=self._tfds_as_supervised,
+        decoders=decoders,
+        read_config=read_config)
+    return dataset
+
+  @property
+  def tfds_info(self) -> tfds.core.DatasetInfo:
+    """Returns TFDS dataset info, if available."""
+    if self._tfds_builder:
+      return self._tfds_builder.info
+    else:
+      raise ValueError('tfds_info is not available, because the dataset '
+                       'is not loaded from tfds.')
+
   def read(
       self,
      input_context: Optional[tf.distribute.InputContext] = None
   ) -> tf.data.Dataset:
     """Generates a tf.data.Dataset object."""
-    if self._num_files > 1:
+    if self._tfds_builder:
+      dataset = self._read_tfds(input_context)
+    elif self._num_files > 1:
       dataset = self._read_sharded_files(input_context)
     else:
       assert self._num_files == 1
......
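
With the reader changes above, the TFDS path is chosen automatically whenever the builder was created, i.e. whenever `tfds_name` is set. A minimal usage sketch, assuming the InputReader constructor's remaining arguments (not shown in this diff) can be left at their defaults, with a placeholder dataset name:

    params = cfg.DataConfig(
        tfds_name='mnist',      # placeholder dataset
        tfds_split='train',
        tfds_download=True,
        global_batch_size=32,
        is_training=True)
    reader = InputReader(params)
    dataset = reader.read()     # dispatches to _read_tfds()
    info = reader.tfds_info     # tfds.core.DatasetInfo for the builder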
...@@ -31,7 +31,12 @@ class DataConfig(base_config.Config): ...@@ -31,7 +31,12 @@ class DataConfig(base_config.Config):
Attributes: Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. (2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas. global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not. is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has drop_remainder: Whether the last batch should be dropped in the case it has
...@@ -41,21 +46,40 @@ class DataConfig(base_config.Config): ...@@ -41,21 +46,40 @@ class DataConfig(base_config.Config):
from disk on the second epoch. Requires significant memory overhead. from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when cycle_length: The number of files that will be processed concurrently when
interleaving files. interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline. sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises produce. If positive, it only takes this number of examples and raises
tf.error.OutOfRangeError after that. Default is -1, meaning it will tf.error.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset. exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
""" """
input_path: str = "" input_path: str = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0 global_batch_size: int = 0
is_training: bool = None is_training: bool = None
drop_remainder: bool = True drop_remainder: bool = True
shuffle_buffer_size: int = 100 shuffle_buffer_size: int = 100
cache: bool = False cache: bool = False
cycle_length: int = 8 cycle_length: int = 8
block_length: int = 1
sharding: bool = True sharding: bool = True
examples_consume: int = -1 examples_consume: int = -1
tfds_data_dir: str = ""
tfds_download: bool = False
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
@dataclasses.dataclass @dataclasses.dataclass
......
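
All of the new fields default to the empty string or False, so existing file-based configs are unaffected; a config that sets both sources is rejected in the InputReader constructor. For instance, with hypothetical values:

    params = cfg.DataConfig(input_path='/data/train*',  # hypothetical pattern
                            tfds_name='mnist')          # placeholder dataset
    reader = InputReader(params)
    # -> ValueError: At most one of `input_path` and `tfds_name` can be specified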