Commit a76bc125, authored by Allen Wang, committed by A. Unique TensorFlower
Browse files

Internal Changes.

PiperOrigin-RevId: 289172941
parent 7124ed12
...@@ -143,7 +143,7 @@ def get_filenames(is_training, data_dir): ...@@ -143,7 +143,7 @@ def get_filenames(is_training, data_dir):
for i in range(128)] for i in range(128)]
def _parse_example_proto(example_serialized): def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image. """Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset The output of the build_image_data.py image preprocessing script is a dataset
...@@ -228,7 +228,7 @@ def parse_record(raw_record, is_training, dtype): ...@@ -228,7 +228,7 @@ def parse_record(raw_record, is_training, dtype):
Returns: Returns:
Tuple with processed image tensor and one-hot-encoded label tensor. Tuple with processed image tensor and one-hot-encoded label tensor.
""" """
image_buffer, label, bbox = _parse_example_proto(raw_record) image_buffer, label, bbox = parse_example_proto(raw_record)
image = preprocess_image( image = preprocess_image(
image_buffer=image_buffer, image_buffer=image_buffer,
...@@ -256,7 +256,8 @@ def input_fn(is_training, ...@@ -256,7 +256,8 @@ def input_fn(is_training,
input_context=None, input_context=None,
drop_remainder=False, drop_remainder=False,
tf_data_experimental_slack=False, tf_data_experimental_slack=False,
training_dataset_cache=False): training_dataset_cache=False,
filenames=None):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -276,11 +277,13 @@ def input_fn(is_training, ...@@ -276,11 +277,13 @@ def input_fn(is_training,
training_dataset_cache: Whether to cache the training dataset on workers. training_dataset_cache: Whether to cache the training dataset on workers.
Typically used to improve training performance when training data is in Typically used to improve training performance when training data is in
remote storage and can fit into worker memory. remote storage and can fit into worker memory.
filenames: Optional field for providing the file names of the TFRecords.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
""" """
filenames = get_filenames(is_training, data_dir) if filenames is None:
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.Dataset.from_tensor_slices(filenames) dataset = tf.data.Dataset.from_tensor_slices(filenames)
if input_context: if input_context:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment