Commit c64cb01b authored by Andrew Audibert, committed by A. Unique TensorFlower
Browse files

Add tf_data_service option to ResNet model.

Tested by running the model on TPU with a tf.data service running in GKE.

PiperOrigin-RevId: 316713637
parent 166f887c
...@@ -100,6 +100,9 @@ class DatasetConfig(base_config.Config): ...@@ -100,6 +100,9 @@ class DatasetConfig(base_config.Config):
skip_decoding: Whether to skip image decoding when loading from TFDS. skip_decoding: Whether to skip image decoding when loading from TFDS.
cache: whether to cache dataset examples. Can be used to avoid re-reading cache: whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead. from disk on the second epoch. Requires significant memory overhead.
tf_data_service: The URI of a tf.data service to offload preprocessing onto
during training. The URI should be in the format "protocol://address",
e.g. "grpc://tf-data-service:5050".
mean_subtract: whether or not to apply mean subtraction to the dataset. mean_subtract: whether or not to apply mean subtraction to the dataset.
standardize: whether or not to apply standardization to the dataset. standardize: whether or not to apply standardization to the dataset.
""" """
...@@ -123,6 +126,7 @@ class DatasetConfig(base_config.Config): ...@@ -123,6 +126,7 @@ class DatasetConfig(base_config.Config):
file_shuffle_buffer_size: int = 1024 file_shuffle_buffer_size: int = 1024
skip_decoding: bool = True skip_decoding: bool = True
cache: bool = False cache: bool = False
tf_data_service: Optional[str] = None
mean_subtract: bool = False mean_subtract: bool = False
standardize: bool = False standardize: bool = False
...@@ -449,6 +453,18 @@ class DatasetBuilder: ...@@ -449,6 +453,18 @@ class DatasetBuilder:
# Prefetch overlaps in-feed with training # Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
if self.config.tf_data_service:
if not hasattr(tf.data.experimental, 'service'):
raise ValueError('The tf_data_service flag requires Tensorflow version '
'>= 2.3.0, but the version is {}'.format(
tf.__version__))
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self.config.tf_data_service,
job_name='resnet_train'))
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset return dataset
def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment