Commit ac0a29f6 authored by Hao Wu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 434505499
parent 9adaa571
@@ -285,8 +285,17 @@ class InputReader:
     if self._enable_tf_data_service:
       # Add a random seed as the tf.data service job name suffix, so tf.data
       # service doesn't reuse the previous state if TPU worker gets preempted.
+      # It's necessary to add the global batch size to the tf.data service
+      # job name because when tuning the batch size with Vizier while tf.data
+      # service is also enabled, the tf.data service job name should differ
+      # across Vizier trials: once the batch size changes, the dataset is,
+      # from the tf.data perspective, a different instance, and a different
+      # job name should be used for tf.data service. Otherwise, the model
+      # would read tensors from the incorrect tf.data service job, which
+      # would cause a dimension mismatch on the batch size dimension.
-      self._tf_data_service_job_name = (
-          params.tf_data_service_job_name + str(self.static_randnum))
+      self._tf_data_service_job_name = (
+          f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
+          f'{self.static_randnum}')
       self._enable_round_robin_tf_data_service = params.get(
           'enable_round_robin_tf_data_service', False)
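
A note on the mechanics behind this hunk: tf.data service shares iteration state between all consumers that register under the same `job_name`, so two Vizier trials that differ only in batch size must end up with distinct names. Below is a minimal standalone sketch of the naming scheme, assuming a hypothetical `make_job_name` helper; only the `tf_data_service_job_name`, `global_batch_size`, and `static_randnum` ingredients come from the actual change:

```python
import random

def make_job_name(base_name: str, global_batch_size: int,
                  static_randnum: int) -> str:
  """Mirrors the commit's scheme: base name + batch size + random suffix."""
  return f'{base_name}_bs{global_batch_size}_{static_randnum}'

# The random suffix prevents reuse of stale service state after a TPU worker
# preemption; the batch size keeps concurrent Vizier trials apart.
randnum = random.randint(0, 2**31 - 1)
job_a = make_job_name('train', 256, randnum)  # e.g. 'train_bs256_12345'
job_b = make_job_name('train', 512, randnum)  # e.g. 'train_bs512_12345'
assert job_a != job_b  # different batch size -> different tf.data service job
```

Presumably this string is what ultimately gets passed as the `job_name` argument of `tf.data.experimental.service.distribute`, which is why a name collision between trials would surface as tensors with the wrong batch dimension.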
@@ -463,9 +472,8 @@ class InputReader:
       dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
     """Generates a tf.data.Dataset object."""
     if dataset is None:
-      dataset = self._read_data_source(
-          self._matched_files, self._dataset_fn, input_context,
-          self._tfds_builder)
+      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
+                                       input_context, self._tfds_builder)
     dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                              input_context)
     dataset = _maybe_map_fn(dataset, self._postprocess_fn)
...
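
The second hunk only reflows a call site, but it shows the shape of the read path: read the source, decode/parse and batch, then apply an optional post-processing map. Here is a simplified, self-contained sketch of that pipeline, not the actual Model Garden implementation; `parse_fn` and `postprocess_fn` stand in for the reader's configured callbacks:

```python
from typing import Callable, Optional
import tensorflow as tf

def _maybe_map_fn(dataset: tf.data.Dataset,
                  fn: Optional[Callable] = None) -> tf.data.Dataset:
  """Applies `fn` only when a function is configured, as in the diff above."""
  return dataset if fn is None else dataset.map(
      fn, num_parallel_calls=tf.data.AUTOTUNE)

def read(source: tf.data.Dataset,
         parse_fn: Callable,
         global_batch_size: int,
         postprocess_fn: Optional[Callable] = None) -> tf.data.Dataset:
  """Parse -> batch -> optional postprocess, echoing InputReader's read()."""
  dataset = source.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(global_batch_size, drop_remainder=True)
  return _maybe_map_fn(dataset, postprocess_fn)

# Tiny usage example with an in-memory dataset.
ds = read(tf.data.Dataset.range(8), lambda x: x * 2, global_batch_size=4)
print(list(ds.as_numpy_iterator()))  # two batches: [0 2 4 6], [8 10 12 14]
```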