Unverified Commit 7d032ea3 authored by guptapriya's avatar guptapriya Committed by GitHub
Browse files

Merge pull request #5900 from tensorflow/priyag-resnet-main-changes

Allow custom parse record method in resnet input functions
parents c9f03bf6 3256c49f
......@@ -111,7 +111,7 @@ def preprocess_image(image, is_training):
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1):
num_parallel_batches=1, parse_record_fn=parse_record):
"""Input function which provides batches for train or eval.
Args:
......@@ -122,6 +122,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
Returns:
A dataset that can be used for iteration.
......@@ -134,7 +135,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
is_training=is_training,
batch_size=batch_size,
shuffle_buffer=_NUM_IMAGES['train'],
parse_record_fn=parse_record,
parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads,
......
......@@ -160,7 +160,7 @@ def parse_record(raw_record, is_training, dtype):
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1):
num_parallel_batches=1, parse_record_fn=parse_record):
"""Input function which provides batches for train or eval.
Args:
......@@ -171,6 +171,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
Returns:
A dataset that can be used for iteration.
......@@ -195,7 +196,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
is_training=is_training,
batch_size=batch_size,
shuffle_buffer=_SHUFFLE_BUFFER,
parse_record_fn=parse_record,
parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment