Commit ca509f79 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 358289482
parent a5e7e2ce
@@ -14,7 +14,7 @@
 """A common dataset reader."""
 
 import random
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 from absl import logging
 import tensorflow as tf
@@ -27,6 +27,13 @@ def _get_random_integer():
   return random.randint(0, (1 << 31) - 1)
 
 
+def _maybe_map_fn(dataset: tf.data.Dataset,
+                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
+  """Calls dataset.map if a valid function is passed in."""
+  return dataset if fn is None else dataset.map(
+      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
@@ -74,38 +81,7 @@
     self._tfds_builder = None
     self._matched_files = []
     if params.input_path:
-      # Read dataset from files.
-      usage = ('`input_path` should be either (1) a str indicating a file '
-               'path/pattern, or (2) a str indicating multiple file '
-               'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
-               '"a,b,c", or (3) a list of str, each of which is a file '
-               'path/pattern or multiple file paths/patterns separated by '
-               'comma, but got: %s')
-      if isinstance(params.input_path, str):
-        input_path_list = [params.input_path]
-      elif isinstance(params.input_path, (list, tuple)):
-        if any(not isinstance(x, str) for x in params.input_path):
-          raise ValueError(usage % params.input_path)
-        input_path_list = params.input_path
-      else:
-        raise ValueError(usage % params.input_path)
-      for input_path in input_path_list:
-        input_patterns = input_path.strip().split(',')
-        for input_pattern in input_patterns:
-          input_pattern = input_pattern.strip()
-          if not input_pattern:
-            continue
-          if '*' in input_pattern or '?' in input_pattern:
-            tmp_matched_files = tf.io.gfile.glob(input_pattern)
-            if not tmp_matched_files:
-              raise ValueError('%s does not match any files.' % input_pattern)
-            self._matched_files.extend(tmp_matched_files)
-          else:
-            self._matched_files.append(input_pattern)
-      if not self._matched_files:
-        raise ValueError('%s does not match any files.' % params.input_path)
+      self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.
       if not params.tfds_split:
@@ -148,15 +124,57 @@
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)
 
+  def _match_files(self, input_path: str) -> List[str]:
+    """Matches files from an input_path."""
+    matched_files = []
+    # Read dataset from files.
+    usage = ('`input_path` should be either (1) a str indicating a file '
+             'path/pattern, or (2) a str indicating multiple file '
+             'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+             '"a,b,c", or (3) a list of str, each of which is a file '
+             'path/pattern or multiple file paths/patterns separated by '
+             'comma, but got: %s')
+    if isinstance(input_path, str):
+      input_path_list = [input_path]
+    elif isinstance(input_path, (list, tuple)):
+      if any(not isinstance(x, str) for x in input_path):
+        raise ValueError(usage % input_path)
+      input_path_list = input_path
+    else:
+      raise ValueError(usage % input_path)
+    for input_path in input_path_list:
+      input_patterns = input_path.strip().split(',')
+      for input_pattern in input_patterns:
+        input_pattern = input_pattern.strip()
+        if not input_pattern:
+          continue
+        if '*' in input_pattern or '?' in input_pattern:
+          tmp_matched_files = tf.io.gfile.glob(input_pattern)
+          if not tmp_matched_files:
+            raise ValueError('%s does not match any files.' % input_pattern)
+          matched_files.extend(tmp_matched_files)
+        else:
+          matched_files.append(input_pattern)
+    if not matched_files:
+      raise ValueError('%s does not match any files.' % input_path)
+    return matched_files
+
   def _shard_files_then_read(
-      self, input_context: Optional[tf.distribute.InputContext] = None):
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Shards the data files and then sent a split to every worker to read."""
-    dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
+    dataset = tf.data.Dataset.from_tensor_slices(matched_files)
 
     # Shuffle and repeat at file level.
     if self._is_training:
       dataset = dataset.shuffle(
-          len(self._matched_files),
+          len(matched_files),
           seed=self._seed,
           reshuffle_each_iteration=True)
@@ -171,7 +189,7 @@
       dataset = dataset.repeat()
 
     dataset = dataset.interleave(
-        map_func=self._dataset_fn,
+        map_func=dataset_fn,
         cycle_length=self._cycle_length,
         block_length=self._block_length,
        num_parallel_calls=(self._cycle_length if self._cycle_length else
@@ -180,9 +198,13 @@
     return dataset
 
   def _read_files_then_shard(
-      self, input_context: Optional[tf.distribute.InputContext] = None):
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Sends all data files to every worker and then shard by data."""
-    dataset = self._dataset_fn(self._matched_files)
+    dataset = dataset_fn(matched_files)
 
     # When `input_file` is a path to a single file or the number of files is
     # less than the number of input pipelines, disable auto sharding
@@ -238,26 +260,35 @@
       raise ValueError('tfds_info is not available, because the dataset '
                        'is not loaded from tfds.')
 
-  def read(
+  def _read_decode_and_parse_dataset(
       self,
-      input_context: Optional[tf.distribute.InputContext] = None
-  ) -> tf.data.Dataset:
-    """Generates a tf.data.Dataset object."""
-    if self._tfds_builder:
+      matched_files: List[str],
+      dataset_fn,
+      batch_size: int,
+      input_context: Optional[tf.distribute.InputContext] = None,
+      tfds_builder: bool = False) -> tf.data.Dataset:
+    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+    if tfds_builder:
       dataset = self._read_tfds(input_context)
     elif len(self._matched_files) > 1:
-      if input_context and (len(self._matched_files) <
+      if input_context and (len(matched_files) <
                             input_context.num_input_pipelines):
         logging.warn(
            'The number of files %d is less than the number of input pipelines '
            '%d. We will send all input files to every worker. '
            'Please consider sharding your data into more files.',
-            len(self._matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(input_context)
+            len(matched_files), input_context.num_input_pipelines)
+        dataset = self._read_files_then_shard(matched_files,
+                                              dataset_fn,
+                                              input_context)
       else:
-        dataset = self._shard_files_then_read(input_context)
-    elif len(self._matched_files) == 1:
-      dataset = self._read_files_then_shard(input_context)
+        dataset = self._shard_files_then_read(matched_files,
+                                              dataset_fn,
+                                              input_context)
+    elif len(matched_files) == 1:
+      dataset = self._read_files_then_shard(matched_files,
+                                            dataset_fn,
+                                            input_context)
     else:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
@@ -268,25 +299,28 @@
     if self._is_training:
       dataset = dataset.shuffle(self._shuffle_buffer_size)
 
-    def maybe_map_fn(dataset, fn):
-      return dataset if fn is None else dataset.map(
-          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-    dataset = maybe_map_fn(dataset, self._decoder_fn)
+    dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
-    dataset = maybe_map_fn(dataset, self._parser_fn)
+    dataset = _maybe_map_fn(dataset, self._parser_fn)
 
     if self._transform_and_batch_fn is not None:
      dataset = self._transform_and_batch_fn(dataset, input_context)
     else:
       per_replica_batch_size = input_context.get_per_replica_batch_size(
-          self._global_batch_size) if input_context else self._global_batch_size
+          batch_size) if input_context else batch_size
       dataset = dataset.batch(
-          per_replica_batch_size, drop_remainder=self._drop_remainder)
+          per_replica_batch_size, drop_remainder=self._drop_remainder
+      )
 
-    dataset = maybe_map_fn(dataset, self._postprocess_fn)
+    return dataset
 
+  def _maybe_apply_data_service(
+      self,
+      dataset: tf.data.Dataset,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Potentially distributes a dataset."""
     if self._enable_tf_data_service and input_context:
       if self._enable_round_robin_tf_data_service:
         replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
@@ -316,6 +350,20 @@
                 processing_mode='parallel_epochs',
                 service=self._tf_data_service_address,
                 job_name=self._tf_data_service_job_name))
+    return dataset
+
+  def read(
+      self,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Generates a tf.data.Dataset object."""
+    dataset = self._read_decode_and_parse_dataset(self._matched_files,
+                                                  self._dataset_fn,
+                                                  self._global_batch_size,
+                                                  input_context,
+                                                  self._tfds_builder)
+    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
+    dataset = self._maybe_apply_data_service(dataset, input_context)
 
     if self._deterministic is not None:
       options = tf.data.Options()
...
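Editor's note on the refactor shown above: the nested `maybe_map_fn` closure that previously lived inside `read()` is hoisted to the module-level `_maybe_map_fn` helper, and `read()` now simply chains `_read_decode_and_parse_dataset`, `_maybe_map_fn(dataset, self._postprocess_fn)`, and `_maybe_apply_data_service`. The following is a minimal standalone sketch of that helper as it appears in the diff; the toy dataset and lambda are illustrative assumptions, not part of the commit.

from typing import Any, Callable, Optional

import tensorflow as tf


def _maybe_map_fn(dataset: tf.data.Dataset,
                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
  """Calls dataset.map with AUTOTUNE parallelism only if `fn` is passed in."""
  return dataset if fn is None else dataset.map(
      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)


# Illustrative usage (toy data, not from the commit):
ds = tf.data.Dataset.range(4)
ds = _maybe_map_fn(ds, lambda x: x * 2)  # fn given: elements are mapped.
ds = _maybe_map_fn(ds)                   # fn is None: dataset returned as-is.
print(list(ds.as_numpy_iterator()))      # [0, 2, 4, 6]

Because the helper is a no-op when `fn` is None, each optional stage (decoder, parser, postprocess) can be wired into the pipeline unconditionally.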