Commit 0f0b060c authored by Hongkun Yu, committed by A. Unique TensorFlower

Refactor input reader

PiperOrigin-RevId: 398593113
parent 201d523a
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
from typing import Any, Callable, List, Optional, Union, Dict, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
from absl import logging
import tensorflow as tf
@@ -34,6 +34,154 @@ def _maybe_map_fn(dataset: tf.data.Dataset,
      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
           'paths/patterns separated by comma (e.g. "a, b, c" or, with no '
           'spaces, "a,b,c"), or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
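
A minimal usage sketch (illustrative only: the paths below are placeholders, and match_files raises ValueError if a pattern matches nothing):

  # A single string may hold one pattern or several comma-separated patterns.
  files = match_files('/data/train-*.tfrecord,/data/extra.tfrecord')
  # A list/tuple of patterns is also accepted; the result is a flat list of paths.
  files = match_files(['/data/train-*.tfrecord', '/data/val-*.tfrecord'])
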
def _read_files_then_shard(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
sharding: bool = False,
repeat: bool = False) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data."""
dataset = dataset_fn(matched_files)
  # When `input_file` is a path to a single file or the number of files is
  # less than the number of input pipelines, disable auto sharding
  # so that the same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if repeat:
dataset = dataset.repeat()
return dataset
def _shard_files_then_read(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
sharding: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None,
deterministic: bool = False) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level.
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if is_training:
    # We need a seed to shuffle the files so that when each TPU worker gets
    # its own shard, the files do not overlap.
if sharding and seed is None:
seed = _get_random_integer()
dataset = dataset.shuffle(
len(matched_files),
seed=seed,
reshuffle_each_iteration=True if not cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if is_training and not cache:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=dataset_fn,
cycle_length=cycle_length,
block_length=block_length,
num_parallel_calls=(cycle_length
if cycle_length else tf.data.experimental.AUTOTUNE),
deterministic=deterministic)
return dataset
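
A rough sketch of how these two helpers are chosen, mirroring the selection logic in InputReader._read_data_source below. `my_files` is a placeholder list of existing file paths, and tf.data.TFRecordDataset stands in for an arbitrary dataset_fn:

  ctx = tf.distribute.InputContext(num_input_pipelines=2, input_pipeline_id=0)
  if len(my_files) >= ctx.num_input_pipelines:
    # Enough files: give each worker its own subset of files.
    ds = _shard_files_then_read(my_files, tf.data.TFRecordDataset, ctx,
                                is_training=True, sharding=True)
  else:
    # Too few files: every worker reads everything, then shards by data.
    ds = _read_files_then_shard(my_files, tf.data.TFRecordDataset, ctx,
                                sharding=True, repeat=True)
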
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
tfds_split: Text,
tfds_skip_decoding_feature: Text,
tfds_as_supervised: bool,
input_context: Optional[tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
  # No-op if the data has already been downloaded and prepared.
tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
decoders = {}
if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if is_training and not cache:
dataset = dataset.repeat()
return dataset
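
A hedged example of driving _read_tfds directly. 'mnist' is only an assumed placeholder for whichever TFDS dataset a config would name, and the module's existing tfds import is assumed:

  builder = tfds.builder('mnist')
  ds = _read_tfds(builder,
                  tfds_split='train',
                  tfds_skip_decoding_feature='',
                  tfds_as_supervised=False,
                  is_training=True)
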
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""
@@ -90,16 +238,7 @@ class InputReader:
    self._tfds_builder = None
    self._matched_files = None
    if params.input_path:
    if not params.input_path:
# we want to combine / mix datasets
if isinstance(params.input_path, cfg.base_config.Config):
self._matched_files = {}
for k, v in params.input_path.as_dict().items():
self._matched_files[k] = self._match_files(v)
# single dataset
else:
self._matched_files = self._match_files(params.input_path)
else:
      # Read dataset from TFDS.
      if not params.tfds_split:
        raise ValueError(
@@ -107,6 +246,8 @@ class InputReader:
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)
else:
self._matched_files = self.get_files(params.input_path)
    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
@@ -149,145 +290,6 @@ class InputReader:
    self._enable_round_robin_tf_data_service = params.get(
        'enable_round_robin_tf_data_service', False)
def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
def _shard_files_then_read(
self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level.
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if self._is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if self._sharding and self._seed is None:
seed = _get_random_integer()
else:
seed = self._seed
dataset = dataset.shuffle(
len(matched_files),
seed=seed,
reshuffle_each_iteration=True if not self._cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=dataset_fn,
cycle_length=self._cycle_length,
block_length=self._block_length,
num_parallel_calls=(self._cycle_length if self._cycle_length else
tf.data.experimental.AUTOTUNE),
deterministic=self._deterministic)
return dataset
def _read_files_then_shard(
self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data."""
dataset = dataset_fn(matched_files)
# When `input_file` is a path to a single file or the number of files is
# less than the number of input pipelines, disable auto sharding
# so that same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
return dataset
def _read_tfds(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
# No op if exist.
self._tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=self._cycle_length,
interleave_block_length=self._block_length,
input_context=input_context,
shuffle_seed=self._seed)
decoders = {}
if self._tfds_skip_decoding_feature:
for skip_feature in self._tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = self._tfds_builder.as_dataset(
split=self._tfds_split,
shuffle_files=self._is_training,
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
return dataset
  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available."""
@@ -297,14 +299,27 @@ class InputReader:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')
  def _read_decode_and_parse_dataset(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: bool = False) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
  def get_files(self, input_path):
    """Gets matched files. Can be overridden by subclasses."""
    if not input_path:
      return None
    # we want to combine / mix datasets
    if isinstance(input_path, cfg.base_config.Config):
      matched_files = {}
      for k, v in input_path.as_dict().items():
        matched_files[k] = match_files(v)
    # single dataset
    else:
      matched_files = match_files(input_path)
    return matched_files
  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: Optional[tfds.core.DatasetBuilder] = None):
    """Reads the data source (files/tfds) to a dataset."""
    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      if len(files) > 1:
@@ -314,15 +329,66 @@ class InputReader:
            '%d. We will send all input files to every worker. '
            'Please consider sharding your data into more files.', len(files),
            input_context.num_input_pipelines)
        return self._read_files_then_shard(files, dataset_fn, input_context)
        return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
      else:
        return self._shard_files_then_read(files, dataset_fn, input_context)
        return _shard_files_then_read(
files,
dataset_fn,
input_context,
seed=self._seed,
is_training=self._is_training,
sharding=self._sharding,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length,
deterministic=self._deterministic)
      elif len(files) == 1:
        return self._read_files_then_shard(files, dataset_fn, input_context)
        return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')
if tfds_builder:
dataset = _read_tfds(
tfds_builder=self._tfds_builder,
tfds_split=self._tfds_split,
tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
tfds_as_supervised=self._tfds_as_supervised,
input_context=input_context,
seed=self._seed,
is_training=self._is_training,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length)
elif isinstance(matched_files, (list, tuple)):
dataset = _files_to_dataset(matched_files)
elif isinstance(matched_files, dict):
dataset = {}
for k, fs in matched_files.items():
dataset[k] = _files_to_dataset(fs)
else:
raise ValueError('`matched_files` should be a list or dict.')
return dataset
def _decode_and_parse_dataset(
      self,
dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
@@ -331,20 +397,9 @@ class InputReader:
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds
    if tfds_builder:
      dataset = self._read_tfds(input_context)
      dataset = _shuffle_and_decode(dataset)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
      dataset = _shuffle_and_decode(dataset)
    elif isinstance(matched_files, dict):
      datasets = {}
      for k, fs in matched_files.items():
        datasets[k] = _files_to_dataset(fs)
        datasets[k] = _shuffle_and_decode(datasets[k])
      dataset = self._combine_fn(datasets)
    else:
      raise ValueError('`matched_files` should be a list or dict.')
    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)
    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
@@ -403,16 +458,16 @@ class InputReader:
              job_name=self._tf_data_service_job_name))
    return dataset
  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    dataset = self._read_decode_and_parse_dataset(self._matched_files,
                                                  self._dataset_fn,
                                                  self._global_batch_size,
                                                  input_context,
                                                  self._tfds_builder)
  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if dataset is None:
      dataset = self._read_data_source(
          self._matched_files, self._dataset_fn, input_context,
          self._tfds_builder)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    dataset = self._maybe_apply_data_service(dataset, input_context)
...
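
One consequence of the refactor, sketched under assumptions: because read() now accepts an optional pre-built dataset and the reading step is separated from decode/parse, a caller can inject its own source. The names params (a hypothetical DataConfig), raw_examples, and the decoding setup are assumptions, not part of this commit:

  reader = input_reader.InputReader(params)
  injected = tf.data.Dataset.from_tensor_slices(raw_examples)
  ds = reader.read(dataset=injected)  # Skips the file/TFDS reading step.
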
@@ -113,7 +113,7 @@ class CombinationDatasetInputReader(input_reader.InputReader):
    self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
    self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
    self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
    self._pseudo_label_matched_files = self._match_files(
    self._pseudo_label_matched_files = input_reader.match_files(
        self._pseudo_label_file_pattern)
    if not self._drop_remainder:
      raise ValueError(
@@ -134,14 +134,20 @@ class CombinationDatasetInputReader(input_reader.InputReader):
          'resulting in a 0 batch size for one of the datasets.'.format(
              self._global_batch_size, self._pseudo_label_data_ratio))
    labeled_dataset = self._read_decode_and_parse_dataset(
    def _read_decode_and_parse_dataset(matched_files, dataset_fn, batch_size,
input_context, tfds_builder):
dataset = self._read_data_source(matched_files, dataset_fn, input_context,
tfds_builder)
return self._decode_and_parse_dataset(dataset, batch_size, input_context)
labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._matched_files,
        dataset_fn=self._dataset_fn,
        batch_size=labeled_batch_size,
        input_context=input_context,
        tfds_builder=self._tfds_builder)
    pseudo_labeled_dataset = self._read_decode_and_parse_dataset(
    pseudo_labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._pseudo_label_matched_files,
        dataset_fn=self._pseudo_label_dataset_fn,
        batch_size=pl_batch_size,
...
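
The same refactor that lets CombinationDatasetInputReader call the module-level input_reader.match_files also makes file matching overridable per subclass via get_files(). A hypothetical sketch (FilteredInputReader and the 'debug' filter are made up for illustration):

  class FilteredInputReader(input_reader.InputReader):

    def get_files(self, input_path):
      files = super().get_files(input_path)
      if isinstance(files, list):
        # Drop any matched file whose name contains 'debug'.
        files = [f for f in files if 'debug' not in f]
      return files
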