Commit 469339ec authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 379618127
parent bcbce005
...@@ -110,14 +110,15 @@ class InputReader: ...@@ -110,14 +110,15 @@ class InputReader:
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
self._seed = params.seed
# When tf.data service is enabled, each data service worker should get # When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None. # different random seeds. Thus, we set `seed` to None.
if params.seed is not None: # Sharding should also be disabled because tf data service handles how
self._seed = params.seed # each worker shard data with `processing_mode` in distribute method.
elif params.enable_tf_data_service: if params.enable_tf_data_service:
self._seed = _get_random_integer()
else:
self._seed = None self._seed = None
self._sharding = False
self._enable_tf_data_service = ( self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address) params.enable_tf_data_service and params.tf_data_service_address)
...@@ -181,16 +182,21 @@ class InputReader: ...@@ -181,16 +182,21 @@ class InputReader:
# If cache is enabled, `reshuffle_each_iteration` is set to False, # If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway. # because we will read the same cached data in every iteration anyway.
if self._is_training: if self._is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if self._sharding and self._seed is None:
seed = _get_random_integer()
else:
seed = self._seed
dataset = dataset.shuffle( dataset = dataset.shuffle(
len(matched_files), len(matched_files),
seed=self._seed, seed=seed,
reshuffle_each_iteration=True if not self._cache else False) reshuffle_each_iteration=True if not self._cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be # Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service. # handled inside tf.data service.
if self._sharding and input_context and ( if self._sharding and input_context and (
input_context.num_input_pipelines > 1 and input_context.num_input_pipelines > 1):
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
...@@ -226,8 +232,7 @@ class InputReader: ...@@ -226,8 +232,7 @@ class InputReader:
# Do not enable sharding if tf.data service is enabled, as sharding will be # Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service. # handled inside tf.data service.
if self._sharding and input_context and ( if self._sharding and input_context and (
input_context.num_input_pipelines > 1 and input_context.num_input_pipelines > 1):
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment