"...models/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "e61747438522e2d5a92b0b7836754be1a7eb9017"
Commit b67a8538 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fix the missing shuffle issues when reading shards data.

PiperOrigin-RevId: 325490166
parent cd89deb7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# ============================================================================== # ==============================================================================
"""A common dataset reader.""" """A common dataset reader."""
import random
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
import tensorflow as tf import tensorflow as tf
...@@ -23,6 +24,10 @@ import tensorflow_datasets as tfds ...@@ -23,6 +24,10 @@ import tensorflow_datasets as tfds
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
def _get_random_integer():
return random.randint(0, (1 << 31) - 1)
class InputReader: class InputReader:
"""Input reader that returns a tf.data.Dataset instance.""" """Input reader that returns a tf.data.Dataset instance."""
...@@ -107,6 +112,7 @@ class InputReader: ...@@ -107,6 +112,7 @@ class InputReader:
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
self._seed = _get_random_integer()
def _read_sharded_files( def _read_sharded_files(
self, self,
...@@ -117,13 +123,21 @@ class InputReader: ...@@ -117,13 +123,21 @@ class InputReader:
dataset = tf.data.Dataset.from_tensor_slices(self._shards) dataset = tf.data.Dataset.from_tensor_slices(self._shards)
else: else:
dataset = tf.data.Dataset.list_files( dataset = tf.data.Dataset.list_files(
self._input_patterns, shuffle=self._is_training) self._input_patterns,
seed=self._seed,
shuffle=self._is_training)
# Shuffle and repeat at file level.
if self._shards and self._is_training:
dataset = dataset.shuffle(
len(self._shards),
seed=self._seed,
reshuffle_each_iteration=True)
if self._sharding and input_context and ( if self._sharding and input_context and (
input_context.num_input_pipelines > 1): input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.interleave( dataset = dataset.interleave(
map_func=self._dataset_fn, map_func=self._dataset_fn,
...@@ -149,8 +163,6 @@ class InputReader: ...@@ -149,8 +163,6 @@ class InputReader:
input_context.num_input_pipelines > 1): input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
def _read_tfds( def _read_tfds(
...@@ -176,8 +188,6 @@ class InputReader: ...@@ -176,8 +188,6 @@ class InputReader:
decoders=decoders, decoders=decoders,
read_config=read_config) read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
@property @property
...@@ -202,12 +212,6 @@ class InputReader: ...@@ -202,12 +212,6 @@ class InputReader:
assert self._num_files == 1 assert self._num_files == 1
dataset = self._read_single_file(input_context) dataset = self._read_single_file(input_context)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
if self._examples_consume > 0: if self._examples_consume > 0:
dataset = dataset.take(self._examples_consume) dataset = dataset.take(self._examples_consume)
...@@ -218,6 +222,16 @@ class InputReader: ...@@ -218,6 +222,16 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._decoder_fn) dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn) dataset = maybe_map_fn(dataset, self._parser_fn)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(
self._shuffle_buffer_size,
seed=self._seed,
reshuffle_each_iteration=True)
dataset = dataset.repeat()
if self._transform_and_batch_fn is not None: if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context) dataset = self._transform_and_batch_fn(dataset, input_context)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment