Commit 9d88abe8 authored by Yeqing Li, committed by A. Unique TensorFlower

Fix the missing shuffle issue when reading sharded data.
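Before this change, the dataset built from an explicit `self._shards` list via `tf.data.Dataset.from_tensor_slices` was never shuffled at the file level, so shards were always read in the same order, and `tf.data.Dataset.list_files` shuffled without an explicit seed. This change adds a per-reader seed, a seeded file-level shuffle for the sharded path, and a seeded per-epoch example shuffle after decoding/parsing. A minimal sketch of the file-level pattern (the shard paths and counts below are hypothetical, not taken from this repo):

    import random

    import tensorflow as tf

    seed = random.randint(0, (1 << 31) - 1)
    shards = ["/data/train-%05d.tfrecord" % i for i in range(256)]  # hypothetical paths

    dataset = tf.data.Dataset.from_tensor_slices(shards)
    # File-level shuffle tied to one seed, reshuffled on every pass over the shards.
    dataset = dataset.shuffle(len(shards), seed=seed, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=16,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)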

PiperOrigin-RevId: 325490166
parent ca9258a9
@@ -15,6 +15,7 @@
 # ==============================================================================
 """A common dataset reader."""
+import random
 from typing import Any, Callable, List, Optional
 import tensorflow as tf
@@ -23,6 +24,10 @@ import tensorflow_datasets as tfds
 from official.modeling.hyperparams import config_definitions as cfg
+def _get_random_integer():
+  return random.randint(0, (1 << 31) - 1)
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
@@ -107,6 +112,7 @@ class InputReader:
     self._parser_fn = parser_fn
     self._transform_and_batch_fn = transform_and_batch_fn
     self._postprocess_fn = postprocess_fn
+    self._seed = _get_random_integer()
   def _read_sharded_files(
       self,
@@ -117,13 +123,21 @@
       dataset = tf.data.Dataset.from_tensor_slices(self._shards)
     else:
       dataset = tf.data.Dataset.list_files(
-          self._input_patterns, shuffle=self._is_training)
+          self._input_patterns,
+          seed=self._seed,
+          shuffle=self._is_training)
+    # Shuffle and repeat at file level.
+    if self._shards and self._is_training:
+      dataset = dataset.shuffle(
+          len(self._shards),
+          seed=self._seed,
+          reshuffle_each_iteration=True)
     if self._sharding and input_context and (
         input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
     if self._is_training:
       dataset = dataset.repeat()
     dataset = dataset.interleave(
         map_func=self._dataset_fn,
@@ -149,8 +163,6 @@
         input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
-    if self._is_training:
-      dataset = dataset.repeat()
     return dataset
   def _read_tfds(
@@ -176,8 +188,6 @@
         decoders=decoders,
         read_config=read_config)
-    if self._is_training:
-      dataset = dataset.repeat()
     return dataset
   @property
@@ -202,12 +212,6 @@
       assert self._num_files == 1
       dataset = self._read_single_file(input_context)
-    if self._cache:
-      dataset = dataset.cache()
-    if self._is_training:
-      dataset = dataset.shuffle(self._shuffle_buffer_size)
-    if self._examples_consume > 0:
-      dataset = dataset.take(self._examples_consume)
@@ -218,6 +222,16 @@
     dataset = maybe_map_fn(dataset, self._decoder_fn)
     dataset = maybe_map_fn(dataset, self._parser_fn)
+    if self._cache:
+      dataset = dataset.cache()
+    if self._is_training:
+      dataset = dataset.shuffle(
+          self._shuffle_buffer_size,
+          seed=self._seed,
+          reshuffle_each_iteration=True)
+      dataset = dataset.repeat()
     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
     else:
......
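Taken together, the last two hunks move caching, example-level shuffling, and repeating out of the read helpers and out of the pre-decode stage, so they now run after `decoder_fn` and `parser_fn`, and the shuffle carries the reader's seed with `reshuffle_each_iteration=True`. A condensed sketch of the resulting order, not the file verbatim; `decode_fn`, `parse_fn`, and the other parameters are placeholders:

    import tensorflow as tf

    def build_pipeline(dataset, decode_fn, parse_fn, is_training,
                       shuffle_buffer_size, batch_size, seed, cache=False):
      # Decode and parse first, mirroring the updated read() path.
      dataset = dataset.map(
          decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      dataset = dataset.map(
          parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      if cache:
        dataset = dataset.cache()
      if is_training:
        # Seeded example-level shuffle that reshuffles every epoch, then repeat.
        dataset = dataset.shuffle(
            shuffle_buffer_size, seed=seed, reshuffle_each_iteration=True)
        dataset = dataset.repeat()
      dataset = dataset.batch(batch_size)
      return dataset.prefetch(tf.data.experimental.AUTOTUNE)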