Commit 98879c05 authored by Frederick Liu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 379618127
parent 588d6da4
......@@ -110,14 +110,15 @@ class InputReader:
self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
self._seed = params.seed
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
if params.seed is not None:
self._seed = params.seed
elif params.enable_tf_data_service:
self._seed = _get_random_integer()
else:
# Sharding should also be disabled because tf.data service handles how
# each worker shards data via `processing_mode` in the distribute method.
if params.enable_tf_data_service:
self._seed = None
self._sharding = False
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
......@@ -181,16 +182,21 @@ class InputReader:
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if self._is_training:
# We need a seed to shuffle the files so that when each TPU worker gets
# its own shard, the files do not overlap.
if self._sharding and self._seed is None:
seed = _get_random_integer()
else:
seed = self._seed
dataset = dataset.shuffle(
len(matched_files),
seed=self._seed,
seed=seed,
reshuffle_each_iteration=True if not self._cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1 and
not self._enable_tf_data_service):
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
......@@ -226,8 +232,7 @@ class InputReader:
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1 and
not self._enable_tf_data_service):
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment