Commit c338144f authored by Yeqing Li, committed by A. Unique TensorFlower

Fix the missing shuffle issue when reading sharded data.

PiperOrigin-RevId: 326246781
parent 7c29567d
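Taken together, the hunks below move repeat() into the dataset-reading helpers (applied only while training) and pull cache()/shuffle() forward to the raw-record stage of read(), replacing the old shuffle-after-parse block. The following is a minimal sketch of that ordering on a generic tf.data pipeline; it is not the InputReader code itself, and the function name, the TFRecordDataset reader, and the parameter defaults are assumptions made for illustration.

# Minimal sketch (not the InputReader implementation) of the ordering above;
# build_sharded_dataset, the TFRecord reader, and all parameter defaults are
# illustrative assumptions.
import tensorflow as tf

def build_sharded_dataset(file_pattern,
                          num_shards=1,
                          shard_index=0,
                          is_training=True,
                          shuffle_buffer_size=10000):
  """Reads sharded TFRecord files with file-level repeat and record shuffle."""
  files = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
  if num_shards > 1:
    # Each input pipeline (host) reads a disjoint subset of the files.
    files = files.shard(num_shards, shard_index)
  if is_training:
    # Repeat at the file level so a training host never exhausts its shard.
    files = files.repeat()
  dataset = files.interleave(
      tf.data.TFRecordDataset,
      cycle_length=16,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if is_training:
    # Shuffle raw records before decoding/parsing, mirroring the relocated
    # cache()/shuffle() block in read().
    dataset = dataset.shuffle(shuffle_buffer_size)
  return dataset

With a tf.distribute InputContext, this would be called once per pipeline, roughly as build_sharded_dataset(pattern, ctx.num_input_pipelines, ctx.input_pipeline_id, is_training=True).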
@@ -138,6 +138,8 @@ class InputReader:
         input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
+    if self._is_training:
+      dataset = dataset.repeat()
     dataset = dataset.interleave(
         map_func=self._dataset_fn,
@@ -163,6 +165,8 @@ class InputReader:
         input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
+    if self._is_training:
+      dataset = dataset.repeat()
     return dataset
 
   def _read_tfds(
@@ -188,6 +192,8 @@ class InputReader:
         decoders=decoders,
         read_config=read_config)
+    if self._is_training:
+      dataset = dataset.repeat()
     return dataset
 
   @property
@@ -212,6 +218,12 @@ class InputReader:
       assert self._num_files == 1
       dataset = self._read_single_file(input_context)
 
+    if self._cache:
+      dataset = dataset.cache()
+
+    if self._is_training:
+      dataset = dataset.shuffle(self._shuffle_buffer_size)
+
     if self._examples_consume > 0:
       dataset = dataset.take(self._examples_consume)
@@ -222,16 +234,6 @@
     dataset = maybe_map_fn(dataset, self._decoder_fn)
     dataset = maybe_map_fn(dataset, self._parser_fn)
-
-    if self._cache:
-      dataset = dataset.cache()
-
-    if self._is_training:
-      dataset = dataset.shuffle(
-          self._shuffle_buffer_size,
-          seed=self._seed,
-          reshuffle_each_iteration=True)
-      dataset = dataset.repeat()
 
     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
     else: