Commit e52c88e5 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Add seed to DataConfig. Add deterministic tests for input reader.

PiperOrigin-RevId: 364734183
parent 6e5c5a1b
......@@ -73,6 +73,7 @@ class DataConfig(base_config.Config):
decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better
performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
"""
input_path: Union[Sequence[str], str] = ""
tfds_name: str = ""
......@@ -92,6 +93,7 @@ class DataConfig(base_config.Config):
tfds_data_dir: str = ""
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
seed: Optional[int] = None
@dataclasses.dataclass
......
......@@ -112,8 +112,12 @@ class InputReader:
self._postprocess_fn = postprocess_fn
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
self._seed = (None
if params.enable_tf_data_service else _get_random_integer())
if params.seed is not None:
self._seed = params.seed
elif params.enable_tf_data_service:
self._seed = _get_random_integer()
else:
self._seed = None
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
......@@ -243,7 +247,8 @@ class InputReader:
read_config = tfds.ReadConfig(
interleave_cycle_length=self._cycle_length,
interleave_block_length=self._block_length,
input_context=input_context)
input_context=input_context,
shuffle_seed=self._seed)
decoders = {}
if self._tfds_skip_decoding_feature:
for skip_feature in self._tfds_skip_decoding_feature.split(','):
......@@ -304,7 +309,7 @@ class InputReader:
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.shuffle(self._shuffle_buffer_size)
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
dataset = _maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None:
......@@ -315,7 +320,7 @@ class InputReader:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.shuffle(self._shuffle_buffer_size)
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment