"examples/vscode:/vscode.git/clone" did not exist on "9c13f8657986e68f5f05987912c54432fd28d86f"
Commit e52c88e5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add seed to DataConfig. Add deterministic tests for input reader.

PiperOrigin-RevId: 364734183
parent 6e5c5a1b
...@@ -73,6 +73,7 @@ class DataConfig(base_config.Config): ...@@ -73,6 +73,7 @@ class DataConfig(base_config.Config):
decoding when loading dataset from TFDS. Use comma to separate multiple decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better features. The main use case is to skip the image/video decoding for better
performance. performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
""" """
input_path: Union[Sequence[str], str] = "" input_path: Union[Sequence[str], str] = ""
tfds_name: str = "" tfds_name: str = ""
...@@ -92,6 +93,7 @@ class DataConfig(base_config.Config): ...@@ -92,6 +93,7 @@ class DataConfig(base_config.Config):
tfds_data_dir: str = "" tfds_data_dir: str = ""
tfds_as_supervised: bool = False tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = "" tfds_skip_decoding_feature: str = ""
seed: Optional[int] = None
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -112,8 +112,12 @@ class InputReader: ...@@ -112,8 +112,12 @@ class InputReader:
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
# When tf.data service is enabled, each data service worker should get # When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None. # different random seeds. Thus, we set `seed` to None.
self._seed = (None if params.seed is not None:
if params.enable_tf_data_service else _get_random_integer()) self._seed = params.seed
elif params.enable_tf_data_service:
self._seed = _get_random_integer()
else:
self._seed = None
self._enable_tf_data_service = ( self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address) params.enable_tf_data_service and params.tf_data_service_address)
...@@ -243,7 +247,8 @@ class InputReader: ...@@ -243,7 +247,8 @@ class InputReader:
read_config = tfds.ReadConfig( read_config = tfds.ReadConfig(
interleave_cycle_length=self._cycle_length, interleave_cycle_length=self._cycle_length,
interleave_block_length=self._block_length, interleave_block_length=self._block_length,
input_context=input_context) input_context=input_context,
shuffle_seed=self._seed)
decoders = {} decoders = {}
if self._tfds_skip_decoding_feature: if self._tfds_skip_decoding_feature:
for skip_feature in self._tfds_skip_decoding_feature.split(','): for skip_feature in self._tfds_skip_decoding_feature.split(','):
...@@ -304,7 +309,7 @@ class InputReader: ...@@ -304,7 +309,7 @@ class InputReader:
# If cache is enabled, we will call `shuffle()` later after `cache()`. # If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache: if self._is_training and not self._cache:
dataset = dataset.shuffle(self._shuffle_buffer_size) dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
dataset = _maybe_map_fn(dataset, self._decoder_fn) dataset = _maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None: if self._sample_fn is not None:
...@@ -315,7 +320,7 @@ class InputReader: ...@@ -315,7 +320,7 @@ class InputReader:
dataset = dataset.cache() dataset = dataset.cache()
if self._is_training: if self._is_training:
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.shuffle(self._shuffle_buffer_size) dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
if self._transform_and_batch_fn is not None: if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context) dataset = self._transform_and_batch_fn(dataset, input_context)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment