"...resnet50_tensorflow.git" did not exist on "9abd85f29a29c6ffc33dbb6ff1ebab0263b2e732"
Commit c338144f authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fix the missing shuffle issues when reading shards data.

PiperOrigin-RevId: 326246781
parent 7c29567d
...@@ -138,6 +138,8 @@ class InputReader: ...@@ -138,6 +138,8 @@ class InputReader:
input_context.num_input_pipelines > 1): input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.interleave( dataset = dataset.interleave(
map_func=self._dataset_fn, map_func=self._dataset_fn,
...@@ -163,6 +165,8 @@ class InputReader: ...@@ -163,6 +165,8 @@ class InputReader:
input_context.num_input_pipelines > 1): input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
def _read_tfds( def _read_tfds(
...@@ -188,6 +192,8 @@ class InputReader: ...@@ -188,6 +192,8 @@ class InputReader:
decoders=decoders, decoders=decoders,
read_config=read_config) read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
@property @property
...@@ -212,6 +218,12 @@ class InputReader: ...@@ -212,6 +218,12 @@ class InputReader:
assert self._num_files == 1 assert self._num_files == 1
dataset = self._read_single_file(input_context) dataset = self._read_single_file(input_context)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
if self._examples_consume > 0: if self._examples_consume > 0:
dataset = dataset.take(self._examples_consume) dataset = dataset.take(self._examples_consume)
...@@ -222,16 +234,6 @@ class InputReader: ...@@ -222,16 +234,6 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._decoder_fn) dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn) dataset = maybe_map_fn(dataset, self._parser_fn)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(
self._shuffle_buffer_size,
seed=self._seed,
reshuffle_each_iteration=True)
dataset = dataset.repeat()
if self._transform_and_batch_fn is not None: if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context) dataset = self._transform_and_batch_fn(dataset, input_context)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment