Commit 9557e02a authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 363529035
parent e4388f88
@@ -44,8 +44,10 @@ class DataConfig(base_config.Config):
     drop_remainder: Whether the last batch should be dropped in the case it has
       fewer than `global_batch_size` elements.
     shuffle_buffer_size: The buffer size used for shuffling training data.
-    cache: Whether to cache dataset examples. Can be used to avoid re-reading
-      from disk on the second epoch. Requires significant memory overhead.
+    cache: Whether to cache dataset examples. If `True`, we will cache the
+      dataset after applying the decode_fn and parse_fn. It can be used to
+      avoid re-reading from disk, re-decoding and re-parsing the examples on
+      the second epoch, but it requires significant memory overhead.
     cycle_length: The number of files that will be processed concurrently when
       interleaving files.
     block_length: The number of consecutive elements to produce from each input
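For context, the behavior described by the updated `cache` docstring corresponds to the following tf.data pattern. This is a minimal sketch, not the module's code; the `decode_fn` and `parse_fn` below are trivial stand-ins for the real decode/parse functions.

import tensorflow as tf

# Hypothetical stand-ins for the pipeline's decode_fn and parse_fn.
decode_fn = lambda x: x + 1
parse_fn = lambda x: x * 2

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(decode_fn).map(parse_fn)
# With cache=True, decoded and parsed elements are materialized on the first
# epoch; later epochs skip disk reads, decoding, and parsing, at the cost of
# holding the entire dataset in memory.
dataset = dataset.cache()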
@@ -174,11 +174,13 @@ class InputReader:
     dataset = tf.data.Dataset.from_tensor_slices(matched_files)
     # Shuffle and repeat at file level.
+    # If cache is enabled, `reshuffle_each_iteration` is set to False,
+    # because we will read the same cached data in every iteration anyway.
     if self._is_training:
       dataset = dataset.shuffle(
           len(matched_files),
           seed=self._seed,
-          reshuffle_each_iteration=True)
+          reshuffle_each_iteration=not self._cache)
     # Do not enable sharding if tf.data service is enabled, as sharding will be
     # handled inside tf.data service.
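The `reshuffle_each_iteration` change follows from the new comment: with caching, every iteration replays the same cached elements, so re-shuffling the file list on each iteration buys nothing. A standalone sketch of the same call, with `cache_enabled` standing in for `self._cache` and an arbitrary seed:

import tensorflow as tf

cache_enabled = True  # stand-in for self._cache
matched_files = ['shard-00000.tfrecord', 'shard-00001.tfrecord']
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# A single fixed file order suffices when the parsed data will be cached;
# without caching, files are reshuffled on every iteration as before.
dataset = dataset.shuffle(
    len(matched_files),
    seed=42,
    reshuffle_each_iteration=not cache_enabled)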
@@ -187,7 +189,9 @@ class InputReader:
         not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
-    if self._is_training:
+    # If cache is enabled, we will call `repeat()` later after `cache()`.
+    if self._is_training and not self._cache:
       dataset = dataset.repeat()
     dataset = dataset.interleave(
@@ -222,7 +226,9 @@ class InputReader:
         not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
-    if self._is_training:
+    # If cache is enabled, we will call `repeat()` later after `cache()`.
+    if self._is_training and not self._cache:
       dataset = dataset.repeat()
     return dataset
@@ -249,7 +255,8 @@ class InputReader:
         decoders=decoders,
         read_config=read_config)
-    if self._is_training:
+    # If cache is enabled, we will call `repeat()` later after `cache()`.
+    if self._is_training and not self._cache:
       dataset = dataset.repeat()
     return dataset
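All three `repeat()` call sites are gated on `not self._cache` for the same reason: `cache()` only completes after one full pass over a finite dataset, and `repeat()` first would make that first pass never end. A small sketch illustrating the ordering (not the module's code):

import tensorflow as tf

ds = tf.data.Dataset.range(4)

# Correct order: cache the finite dataset once, then repeat the cached copy.
good = ds.cache().repeat()
print(list(good.take(8).as_numpy_iterator()))  # [0, 1, 2, 3, 0, 1, 2, 3]

# Problematic order: repeat() first makes the dataset unbounded, so cache()
# could never finish its initial pass:
# bad = ds.repeat().cache()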
@@ -295,10 +302,8 @@ class InputReader:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
-    if self._cache:
-      dataset = dataset.cache()
-
-    if self._is_training:
+    # If cache is enabled, we will call `shuffle()` later after `cache()`.
+    if self._is_training and not self._cache:
       dataset = dataset.shuffle(self._shuffle_buffer_size)
     dataset = _maybe_map_fn(dataset, self._decoder_fn)
@@ -306,6 +311,12 @@ class InputReader:
       dataset = dataset.apply(self._sample_fn)
     dataset = _maybe_map_fn(dataset, self._parser_fn)
+    if self._cache:
+      dataset = dataset.cache()
+      if self._is_training:
+        dataset = dataset.repeat()
+        dataset = dataset.shuffle(self._shuffle_buffer_size)
+
     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
     else:
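Taken together, the change moves the element-level stages into the order decode, parse, cache, repeat, shuffle when caching is on. A self-contained sketch of that ordering; the flag values, buffer size, and map functions are invented for illustration:

import tensorflow as tf

is_training = True        # stand-in for self._is_training
use_cache = True          # stand-in for self._cache
shuffle_buffer_size = 64  # stand-in for self._shuffle_buffer_size

ds = tf.data.Dataset.range(100)
ds = ds.map(lambda x: x + 1)  # decoder_fn stand-in
ds = ds.map(lambda x: x * 2)  # parser_fn stand-in
if use_cache:
  ds = ds.cache()             # cache the decoded/parsed elements
  if is_training:
    ds = ds.repeat()          # repeat only after cache() sees a finite pass
    ds = ds.shuffle(shuffle_buffer_size)  # element-level shuffle per epoch
ds = ds.batch(8, drop_remainder=True)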