Commit a76bc125 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal Changes.

PiperOrigin-RevId: 289172941
parent 7124ed12
......@@ -143,7 +143,7 @@ def get_filenames(is_training, data_dir):
for i in range(128)]
def _parse_example_proto(example_serialized):
def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset
......@@ -228,7 +228,7 @@ def parse_record(raw_record, is_training, dtype):
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image_buffer, label, bbox = _parse_example_proto(raw_record)
image_buffer, label, bbox = parse_example_proto(raw_record)
image = preprocess_image(
image_buffer=image_buffer,
......@@ -256,7 +256,8 @@ def input_fn(is_training,
input_context=None,
drop_remainder=False,
tf_data_experimental_slack=False,
training_dataset_cache=False):
training_dataset_cache=False,
filenames=None):
"""Input function which provides batches for train or eval.
Args:
......@@ -276,10 +277,12 @@ def input_fn(is_training,
training_dataset_cache: Whether to cache the training dataset on workers.
Typically used to improve training performance when training data is in
remote storage and can fit into worker memory.
filenames: Optional field for providing the file names of the TFRecords.
Returns:
A dataset that can be used for iteration.
"""
if filenames is None:
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment