"...srt/layers/git@developer.sourcefind.cn:change/sglang.git" did not exist on "13bc39c5d631f7783675674fa7615d594ed533b2"
Commit 3256c49f authored by Priya Gupta's avatar Priya Gupta
Browse files

allow specifying parse function

parent c9f03bf6
...@@ -111,7 +111,7 @@ def preprocess_image(image, is_training): ...@@ -111,7 +111,7 @@ def preprocess_image(image, is_training):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None, dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1): num_parallel_batches=1, parse_record_fn=parse_record):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -122,6 +122,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -122,6 +122,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype: Data type to use for images/features dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data. datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data. num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -134,7 +135,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -134,7 +135,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
is_training=is_training, is_training=is_training,
batch_size=batch_size, batch_size=batch_size,
shuffle_buffer=_NUM_IMAGES['train'], shuffle_buffer=_NUM_IMAGES['train'],
parse_record_fn=parse_record, parse_record_fn=parse_record_fn,
num_epochs=num_epochs, num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads, datasets_num_private_threads=datasets_num_private_threads,
......
...@@ -160,7 +160,7 @@ def parse_record(raw_record, is_training, dtype): ...@@ -160,7 +160,7 @@ def parse_record(raw_record, is_training, dtype):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None, dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1): num_parallel_batches=1, parse_record_fn=parse_record):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -171,6 +171,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -171,6 +171,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype: Data type to use for images/features dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data. datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data. num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -195,7 +196,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -195,7 +196,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
is_training=is_training, is_training=is_training,
batch_size=batch_size, batch_size=batch_size,
shuffle_buffer=_SHUFFLE_BUFFER, shuffle_buffer=_SHUFFLE_BUFFER,
parse_record_fn=parse_record, parse_record_fn=parse_record_fn,
num_epochs=num_epochs, num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads, datasets_num_private_threads=datasets_num_private_threads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment