Commit d0b623d4 authored by Hao Wu's avatar Hao Wu Committed by A. Unique TensorFlower

Internal changes.

PiperOrigin-RevId: 465438835
parent 9b47a723
......@@ -75,6 +75,30 @@ class DataConfig(base_config.Config):
decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better
performance.
enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
true, only a single tf.data service will be started, and it will be shared
between all the trainers running simultaneously, e.g. when using Vizier to
tune hyperparameters. This saves CPU and RAM resources compared to running
a separate tf.data service for each trainer. Note that if the batch size
differs between trainers, the field
apply_tf_data_service_before_batching also needs to be true so that only a
single tf.data service instance is created. In this case, tf.data service
is applied before the batching operation, so make sure not to apply any
processing steps after batching (e.g. in postprocess_fn) since they would
not be parallelized by tf.data service and may slow down your tf.data
pipeline. When using a shared tf.data service, the tf.data dataset must be
infinite, and a slow trainer may skip certain training examples.
More details about the shared tf.data service can be found at:
https://www.tensorflow.org/api_docs/python/tf/data/experimental/service?version=nightly#sharing_tfdata_service_with_concurrent_trainers.
apply_tf_data_service_before_batching: A bool. If set to True, tf.data
service will be applied before the batching operation. This is useful to
make sure only a single tf.data service instance is created when
enable_shared_tf_data_service_between_parallel_trainers is true and the
batch size differs between parallel trainers.
trainer_id: A string. The id of the trainer when there are multiple parallel
trainers running at the same time, e.g. in the Vizier tuning case. It will
be set automatically if needed; users do not need to set it when creating
experiment configs.
seed: An optional seed to use for deterministic shuffling/preprocessing.
prefetch_buffer_size: An int specifying the buffer size of prefetch
datasets. If None, the buffer size is autotuned. Specifying this is useful
......@@ -99,6 +123,9 @@ class DataConfig(base_config.Config):
tfds_data_dir: str = ""
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
enable_shared_tf_data_service_between_parallel_trainers: bool = False
apply_tf_data_service_before_batching: bool = False
trainer_id: Optional[str] = None
seed: Optional[int] = None
prefetch_buffer_size: Optional[int] = None
......
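As a hedged illustration (not part of this change), the sketch below shows how
these new DataConfig fields might be combined when several parallel trainers
with different global batch sizes share one tf.data service; the input path,
service address, and trainer id are hypothetical placeholders.

    # Minimal sketch, assuming a DataConfig-based experiment config.
    train_data = DataConfig(
        input_path='/path/to/train*',  # hypothetical path
        is_training=True,
        global_batch_size=64,  # may differ between parallel trainers
        enable_tf_data_service=True,
        tf_data_service_address='grpc://tf-data-service:5050',  # placeholder
        # Share a single tf.data service between all parallel trainers.
        enable_shared_tf_data_service_between_parallel_trainers=True,
        # Required when batch sizes differ, so the service is applied (and
        # therefore shared) before the batching operation.
        apply_tf_data_service_before_batching=True,
        # Normally set automatically by the tuning framework.
        trainer_id='trainer_0')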
......@@ -292,8 +292,8 @@ class InputReader:
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
self._seed = params.seed
self._prefetch_buffer_size = (params.prefetch_buffer_size or
tf.data.experimental.AUTOTUNE)
self._prefetch_buffer_size = (
params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
......@@ -306,6 +306,11 @@ class InputReader:
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address
self._enable_shared_tf_data_service_between_parallel_trainers = (
params.enable_shared_tf_data_service_between_parallel_trainers)
self._apply_tf_data_service_before_batching = (
params.apply_tf_data_service_before_batching)
self._trainer_id = params.trainer_id
if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
......@@ -322,6 +327,15 @@ class InputReader:
f'{self.static_randnum}')
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
if self._enable_shared_tf_data_service_between_parallel_trainers:
# When shared tf.data service is enabled, only a single tf.data service
# instance should be created and shared between parallel trainers. If
# the global batch size is different across trainers,
# params.apply_tf_data_service_before_batching should be set to true
# because tf.data services with different batch sizes will be considered
# separate instances.
self._tf_data_service_job_name = (
f'{params.tf_data_service_job_name}_{self.static_randnum}')
@property
def tfds_info(self) -> tfds.core.DatasetInfo:
......@@ -444,6 +458,19 @@ class InputReader:
dataset = dataset.repeat()
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Applies tf.data service before the batching operation. This is useful
# when the tf.data service is shared between parallel trainers and the
# batch size differs between them. When the batch size differs, tf.data
# services applied after the batching operation are considered separate
# instances, which makes them difficult to share between parallel trainers.
# However, if there are additional expensive operations in
# self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
# pipeline could be slowed down. In that case, try to move these dataset
# operations into earlier stages if possible.
if (self._enable_shared_tf_data_service_between_parallel_trainers and
self._apply_tf_data_service_before_batching):
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
else:
......@@ -469,13 +496,18 @@ class InputReader:
num_consumers = input_context.num_input_pipelines * (
replicas_per_input_pipeline)
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
tfds_kwargs = {
'processing_mode': 'parallel_epochs',
'service': self._tf_data_service_address,
'job_name': self._tf_data_service_job_name,
'num_consumers': num_consumers
}
if self._enable_shared_tf_data_service_between_parallel_trainers:
raise ValueError('Shared tf.data service does not support round-robin'
' tf.data service.')
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name,
consumer_index=base_consumer_index + i,
num_consumers=num_consumers)))
consumer_index=base_consumer_index + i, **tfds_kwargs)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset = dataset.interleave(
......@@ -484,11 +516,21 @@ class InputReader:
num_parallel_calls=replicas_per_input_pipeline,
deterministic=True)
else:
tfds_kwargs = {
'processing_mode': 'parallel_epochs',
'service': self._tf_data_service_address,
'job_name': self._tf_data_service_job_name,
}
if self._enable_shared_tf_data_service_between_parallel_trainers:
tfds_kwargs.update({
'processing_mode':
tf.data.experimental.service.ShardingPolicy.OFF,
'cross_trainer_cache':
tf.data.experimental.service.CrossTrainerCache(
trainer_id=self._trainer_id)
})
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name))
tf.data.experimental.service.distribute(**tfds_kwargs))
return dataset
def read(self,
......@@ -501,6 +543,8 @@ class InputReader:
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context)
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
if not (self._enable_shared_tf_data_service_between_parallel_trainers and
self._apply_tf_data_service_before_batching):
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._deterministic is not None:
......
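For context, here is a hedged standalone sketch of what the shared-service
branch above amounts to at the tf.data level; the service address, job name,
and trainer id are hypothetical placeholders, and the dataset is kept infinite
as the shared service requires.

    import tensorflow as tf

    # Minimal sketch, assuming a TF version whose
    # tf.data.experimental.service.distribute supports cross_trainer_cache.
    dataset = tf.data.Dataset.range(1_000_000).repeat()  # must be infinite
    dataset = dataset.apply(
        tf.data.experimental.service.distribute(
            # No sharding, so every trainer reads from the same job.
            processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
            service='grpc://tf-data-service:5050',  # placeholder address
            job_name='shared_train_job',  # same job name across trainers
            # The cross-trainer cache lets concurrently running trainers
            # reuse the same service output; trainer_id identifies this
            # trainer to the shared cache.
            cross_trainer_cache=tf.data.experimental.service.CrossTrainerCache(
                trainer_id='trainer_0')))  # unique per trainer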