Commit 1b45a4a5 authored by Rick Chao's avatar Rick Chao Committed by A. Unique TensorFlower
Browse files

Remove num_epochs in imagenet_preprocessing.py as it's unused.

PiperOrigin-RevId: 299424863
parent 11ccb99e
...@@ -115,7 +115,6 @@ def get_filenames(is_training, data_dir): ...@@ -115,7 +115,6 @@ def get_filenames(is_training, data_dir):
def input_fn(is_training, def input_fn(is_training,
data_dir, data_dir,
batch_size, batch_size,
num_epochs=1,
dtype=tf.float32, dtype=tf.float32,
datasets_num_private_threads=None, datasets_num_private_threads=None,
parse_record_fn=parse_record, parse_record_fn=parse_record,
...@@ -127,7 +126,6 @@ def input_fn(is_training, ...@@ -127,7 +126,6 @@ def input_fn(is_training,
is_training: A boolean denoting whether the input is for training. is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
dtype: Data type to use for images/features dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data. datasets_num_private_threads: Number of private threads for tf.data.
parse_record_fn: Function to use for parsing the records. parse_record_fn: Function to use for parsing the records.
...@@ -155,7 +153,6 @@ def input_fn(is_training, ...@@ -155,7 +153,6 @@ def input_fn(is_training,
batch_size=batch_size, batch_size=batch_size,
shuffle_buffer=NUM_IMAGES['train'], shuffle_buffer=NUM_IMAGES['train'],
parse_record_fn=parse_record_fn, parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads, datasets_num_private_threads=datasets_num_private_threads,
drop_remainder=drop_remainder drop_remainder=drop_remainder
......
...@@ -67,7 +67,6 @@ def process_record_dataset(dataset, ...@@ -67,7 +67,6 @@ def process_record_dataset(dataset,
batch_size, batch_size,
shuffle_buffer, shuffle_buffer,
parse_record_fn, parse_record_fn,
num_epochs=1,
dtype=tf.float32, dtype=tf.float32,
datasets_num_private_threads=None, datasets_num_private_threads=None,
drop_remainder=False, drop_remainder=False,
...@@ -83,7 +82,6 @@ def process_record_dataset(dataset, ...@@ -83,7 +82,6 @@ def process_record_dataset(dataset,
time and use less memory. time and use less memory.
parse_record_fn: A function that takes a raw record and returns the parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair. corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset.
dtype: Data type to use for images/features. dtype: Data type to use for images/features.
datasets_num_private_threads: Number of threads for a private datasets_num_private_threads: Number of threads for a private
threadpool created for all datasets computation. threadpool created for all datasets computation.
...@@ -276,7 +274,6 @@ def get_parse_record_fn(use_keras_image_data_format=False): ...@@ -276,7 +274,6 @@ def get_parse_record_fn(use_keras_image_data_format=False):
def input_fn(is_training, def input_fn(is_training,
data_dir, data_dir,
batch_size, batch_size,
num_epochs=1,
dtype=tf.float32, dtype=tf.float32,
datasets_num_private_threads=None, datasets_num_private_threads=None,
parse_record_fn=parse_record, parse_record_fn=parse_record,
...@@ -291,7 +288,6 @@ def input_fn(is_training, ...@@ -291,7 +288,6 @@ def input_fn(is_training,
is_training: A boolean denoting whether the input is for training. is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
dtype: Data type to use for images/features dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data. datasets_num_private_threads: Number of private threads for tf.data.
parse_record_fn: Function to use for parsing the records. parse_record_fn: Function to use for parsing the records.
...@@ -344,7 +340,6 @@ def input_fn(is_training, ...@@ -344,7 +340,6 @@ def input_fn(is_training,
batch_size=batch_size, batch_size=batch_size,
shuffle_buffer=_SHUFFLE_BUFFER, shuffle_buffer=_SHUFFLE_BUFFER,
parse_record_fn=parse_record_fn, parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads, datasets_num_private_threads=datasets_num_private_threads,
drop_remainder=drop_remainder, drop_remainder=drop_remainder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment