"vscode:/vscode.git/clone" did not exist on "b57abe16632605ae9e8b0473dbb45fb0fd25e6f1"
Commit ac0a29f6 authored by Hao Wu's avatar Hao Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 434505499
parent 9adaa571
......@@ -285,8 +285,17 @@ class InputReader:
if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
# It's necessary to add global batch size into the tf data service job
# name because when tuning batch size with vizier and tf data service is
# also enable, the tf data servce job name should be different for
# different vizier trials since once batch size is changed, from the
# tf.data perspective, the dataset is a different instance, and a
# different job name should be used for tf data service. Otherwise, the
# model would read tensors from the incorrect tf data service job, which
# would causes dimension mismatch on the batch size dimension.
self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum))
f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
f'{self.static_randnum}')
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
......@@ -463,9 +472,8 @@ class InputReader:
dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
if dataset is None:
dataset = self._read_data_source(
self._matched_files, self._dataset_fn, input_context,
self._tfds_builder)
dataset = self._read_data_source(self._matched_files, self._dataset_fn,
input_context, self._tfds_builder)
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context)
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment