Commit 5b60084d authored by Ruoxin Sang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 352730170
parent e31d3f37
@@ -30,6 +30,10 @@ def _get_random_integer():
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
 
+  # A static random number which is the same across different InputReader
+  # instances.
+  static_randnum = _get_random_integer()
+
   def __init__(self,
                params: cfg.DataConfig,
                dataset_fn=tf.data.TFRecordDataset,
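
For reference, a minimal sketch (not from this commit; the body of _get_random_integer is a stand-in, since the diff only shows its name) of why a class-level attribute gives every InputReader the same number: it is evaluated once, when the class is defined, rather than once per instance.

    import random

    def _get_random_integer():
      # Hypothetical body; the real implementation is not shown in this diff.
      return random.randint(0, 2**31 - 1)

    class InputReader:
      # Evaluated once at class-definition time and shared by all instances.
      static_randnum = _get_random_integer()

    # Every reader in this process sees the same value; a restarted
    # process draws a fresh one.
    assert InputReader().static_randnum == InputReader().static_randnum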
@@ -136,7 +140,13 @@ class InputReader:
     self._enable_tf_data_service = (
         params.enable_tf_data_service and params.tf_data_service_address)
     self._tf_data_service_address = params.tf_data_service_address
-    self._tf_data_service_job_name = params.tf_data_service_job_name
+    if self._enable_tf_data_service:
+      # Add a random seed as the tf.data service job name suffix, so tf.data
+      # service doesn't reuse the previous state if TPU worker gets preempted.
+      self._tf_data_service_job_name = (
+          params.tf_data_service_job_name + str(self.static_randnum))
+      self._enable_round_robin_tf_data_service = params.get(
+          'enable_round_robin_tf_data_service', False)
 
   def _shard_files_then_read(
       self, input_context: Optional[tf.distribute.InputContext] = None):
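
A small illustration of the naming scheme above, with hypothetical values (neither the job name nor the number comes from the commit):

    # Hypothetical values, illustrating the job-name suffix.
    params_job_name = 'train'      # stands in for params.tf_data_service_job_name
    static_randnum = 803741213     # drawn once per process
    job_name = params_job_name + str(static_randnum)  # 'train803741213'
    # Identical for every InputReader in this process; different after a
    # restart, so tf.data service starts a fresh job rather than resuming
    # the state left behind by a preempted worker.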
@@ -276,7 +286,30 @@ class InputReader:
     dataset = maybe_map_fn(dataset, self._postprocess_fn)
 
-    if self._enable_tf_data_service:
-      dataset = dataset.apply(
-          tf.data.experimental.service.distribute(
-              processing_mode='parallel_epochs',
+    if self._enable_tf_data_service and input_context:
+      if self._enable_round_robin_tf_data_service:
+        replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
+            input_context.num_input_pipelines)
+        base_consumer_index = input_context.input_pipeline_id * (
+            replicas_per_input_pipeline)
+        num_consumers = input_context.num_input_pipelines * (
+            replicas_per_input_pipeline)
+        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
+        dataset = range_dataset.map(lambda i: dataset.apply(  # pylint: disable=g-long-lambda
+            tf.data.experimental.service.distribute(
+                processing_mode='parallel_epochs',
+                service=self._tf_data_service_address,
+                job_name=self._tf_data_service_job_name,
+                consumer_index=base_consumer_index + i,
+                num_consumers=num_consumers)))
+        # Use parallel interleave to read multiple batches from a tf.data
+        # service worker in parallel.
+        dataset = dataset.interleave(
+            lambda x: x,
+            cycle_length=replicas_per_input_pipeline,
+            num_parallel_calls=replicas_per_input_pipeline,
+            deterministic=True)
+      else:
+        dataset = dataset.apply(
+            tf.data.experimental.service.distribute(
+                processing_mode='parallel_epochs',
...
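
To sanity-check the consumer-index arithmetic in the round-robin branch, here is a worked example with a hypothetical topology (2 input pipelines feeding 8 replicas in sync; none of these numbers come from the commit):

    # 8 replicas in sync, fed by 2 input pipelines (one per host).
    num_replicas_in_sync = 8
    num_input_pipelines = 2

    replicas_per_input_pipeline = num_replicas_in_sync // num_input_pipelines  # 4
    num_consumers = num_input_pipelines * replicas_per_input_pipeline          # 8

    for input_pipeline_id in range(num_input_pipelines):
      base_consumer_index = input_pipeline_id * replicas_per_input_pipeline
      print([base_consumer_index + i
             for i in range(replicas_per_input_pipeline)])
    # Pipeline 0 registers consumers [0, 1, 2, 3]; pipeline 1 registers
    # [4, 5, 6, 7]. Each replica gets a distinct consumer_index out of
    # num_consumers, so tf.data service workers hand out batches round-robin
    # across all 8 consumers, and the interleave reads the 4 per-replica
    # streams on each host in parallel.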