Unverified commit 4cfa0d3b authored by Neal Wu, committed by GitHub

Input improvements (#2706)

parent dcc23689
@@ -79,13 +79,16 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
     # a small dataset, we can easily shuffle the full epoch.
     dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
+  dataset = dataset.repeat(num_epochs)
+
   # Map example_parser over dataset, and batch results by up to batch_size
   dataset = dataset.map(
       example_parser, num_threads=1, output_buffer_size=batch_size)
   dataset = dataset.batch(batch_size)
-  images, labels = dataset.make_one_shot_iterator().get_next()
+  iterator = dataset.make_one_shot_iterator()
+  images, labels = iterator.get_next()
 
   return images, labels
...
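The "repeat after shuffling" comment is the heart of this change. Below is a minimal sketch of why the order matters, using a toy dataset and tf.data.Dataset (the successor of the tf.contrib.data API used in this diff); the variable names are illustrative only:

import tensorflow as tf

# Toy dataset of 4 elements, iterated for 2 epochs.
dataset = tf.data.Dataset.range(4)

# Shuffle-then-repeat: every epoch is a complete, freshly shuffled pass.
shuffled_first = dataset.shuffle(buffer_size=4).repeat(2)
# Repeat-then-shuffle: the buffer can mix elements across epoch boundaries,
# so one element may appear twice before another appears at all.
repeated_first = dataset.repeat(2).shuffle(buffer_size=8)

next_element = shuffled_first.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print([sess.run(next_element) for _ in range(8)])  # e.g. [2, 0, 3, 1, 1, 3, 0, 2]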
@@ -166,11 +166,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       num_threads=1,
       output_buffer_size=2 * batch_size)
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
+  dataset = dataset.repeat(num_epochs)
+
-  # Batch results by up to batch_size, and then fetch the tuple from the
-  # iterator.
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  dataset = dataset.batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
...
@@ -73,6 +73,7 @@ _NUM_IMAGES = {
     'validation': 50000,
 }
 
+_FILE_SHUFFLE_BUFFER = 1024
+_SHUFFLE_BUFFER = 1500
@@ -81,11 +82,11 @@ def filenames(is_training, data_dir):
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(0, 1024)]
+        for i in range(1024)]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
-        for i in range(0, 128)]
+        for i in range(128)]
 
 def dataset_parser(value, is_training):
@@ -137,11 +138,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       filenames(is_training, data_dir))
 
   if is_training:
-    dataset = dataset.shuffle(buffer_size=1024)
-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  if is_training:
-    dataset = dataset.repeat(num_epochs)
+  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
 
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
                         num_threads=5,
@@ -152,7 +151,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
+  dataset = dataset.repeat(num_epochs)
+
+  dataset = dataset.batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
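The ImageNet pipeline now shuffles at two granularities: whole files with _FILE_SHUFFLE_BUFFER = 1024 (enough to cover all 1024 training shards) and parsed records with _SHUFFLE_BUFFER = 1500, trading memory for randomness as the comment above notes. A condensed sketch of the resulting order of operations, assuming (as the hunks suggest but do not fully show) that both shuffles apply only during training; sketch_input_fn and parser are placeholder names, not the real functions:

import tensorflow as tf

_FILE_SHUFFLE_BUFFER = 1024  # covers every training shard
_SHUFFLE_BUFFER = 1500       # records held in memory for shuffling

def sketch_input_fn(file_names, parser, batch_size, num_epochs, is_training):
  dataset = tf.data.Dataset.from_tensor_slices(file_names)
  if is_training:
    # Coarse randomness first: reorder whole files cheaply.
    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
  # Read the records out of each file, in file order.
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  dataset = dataset.map(parser)
  if is_training:
    # Finer randomness second: bigger buffers shuffle better but cost memory.
    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
  # Repeat after shuffling so epochs do not blend together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset.make_one_shot_iterator().get_next()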
@@ -188,8 +192,8 @@ def resnet_model_fn(features, labels, mode, params):
       [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
 
   if mode == tf.estimator.ModeKeys.TRAIN:
-    # Scale the learning rate linearly with the batch size. When the batch size is
-    # 256, the learning rate should be 0.1.
+    # Scale the learning rate linearly with the batch size. When the batch size
+    # is 256, the learning rate should be 0.1.
     initial_learning_rate = 0.1 * params['batch_size'] / 256
     batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
     global_step = tf.train.get_or_create_global_step()
...
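The reworded comment in the hunk above encodes a linear scaling rule: the learning rate grows proportionally with the batch size, anchored at 0.1 for batch size 256. Spelled out with a few sample batch sizes:

# initial_learning_rate = 0.1 * batch_size / 256
for batch_size in (64, 128, 256, 512, 1024):
  print(batch_size, 0.1 * batch_size / 256)
# 64 -> 0.025, 128 -> 0.05, 256 -> 0.1, 512 -> 0.2, 1024 -> 0.4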
@@ -61,6 +61,8 @@ parser.add_argument(
     '--test_data', type=str, default='/tmp/census_data/adult.test',
     help='Path to the test data.')
 
+_SHUFFLE_BUFFER = 100000
+
 def build_model_columns():
   """Builds a set of wide and deep feature columns."""
@@ -167,6 +169,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   assert tf.gfile.Exists(data_file), (
       '%s not found. Please make sure you have either run data_download.py or '
       'set both arguments --train_data and --test_data.' % data_file)
 
   def parse_csv(value):
     print('Parsing', data_file)
     columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
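parse_csv relies on tf.decode_csv, which turns each text line into typed column tensors according to record_defaults. A minimal sketch with a toy three-column schema (not the actual census schema or its _CSV_COLUMN_DEFAULTS):

import tensorflow as tf

# record_defaults fixes both the dtype and the fill value for each column.
record_defaults = [[0], [''], [0.0]]  # int, string, float columns
line = tf.constant('39,State-gov,77516.5')
age, workclass, weight = tf.decode_csv(line, record_defaults=record_defaults)

with tf.Session() as sess:
  print(sess.run([age, workclass, weight]))  # [39, b'State-gov', 77516.5]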
@@ -178,49 +181,36 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   dataset = tf.contrib.data.TextLineDataset(data_file)
   dataset = dataset.map(parse_csv, num_threads=5)
 
-  # Apply transformations to the Dataset
-  dataset = dataset.batch(batch_size)
+  if shuffle:
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
+  dataset = dataset.batch(batch_size)
 
-  # Input function that is called by the Estimator
-  def _input_fn():
-    if shuffle:
-      # Apply shuffle transformation to re-shuffle the dataset in each call.
-      shuffled_dataset = dataset.shuffle(buffer_size=100000)
-      iterator = shuffled_dataset.make_one_shot_iterator()
-    else:
-      iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    return features, labels
-  return _input_fn
+  iterator = dataset.make_one_shot_iterator()
+  features, labels = iterator.get_next()
+  return features, labels
 
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
   model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
 
-  # Set up input function generators for the train and test data files.
-  train_input_fn = input_fn(
-      data_file=FLAGS.train_data,
-      num_epochs=FLAGS.epochs_per_eval,
-      shuffle=True,
-      batch_size=FLAGS.batch_size)
-  eval_input_fn = input_fn(
-      data_file=FLAGS.test_data,
-      num_epochs=1,
-      shuffle=False,
-      batch_size=FLAGS.batch_size)
-
   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
-    model.train(input_fn=train_input_fn)
-    results = model.evaluate(input_fn=eval_input_fn)
+    model.train(input_fn=lambda: input_fn(
+        FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
+
+    results = model.evaluate(input_fn=lambda: input_fn(
+        FLAGS.test_data, 1, False, FLAGS.batch_size))
 
     # Display evaluation metrics
     print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
     print('-' * 30)
 
     for key in sorted(results):
       print('%s: %s' % (key, results[key]))
...
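The wide_deep change switches input_fn from returning a closure to building the pipeline and returning the (features, labels) tensors directly; call sites therefore wrap it in `lambda: input_fn(...)`, since an Estimator expects a zero-argument callable it can invoke inside a fresh graph on every train() or evaluate() call. A minimal sketch of that contract, with a hypothetical my_input_fn and made-up toy data:

import tensorflow as tf

def my_input_fn():
  # Invoked by the Estimator inside its own new graph on each train() or
  # evaluate() call, so all ops must be created here rather than captured
  # from an earlier graph.
  dataset = tf.data.Dataset.from_tensor_slices(
      ({'x': [[1.0], [2.0], [3.0]]}, [0.0, 1.0, 2.0]))
  dataset = dataset.repeat(1).batch(1)
  return dataset.make_one_shot_iterator().get_next()

feature_columns = [tf.feature_column.numeric_column('x')]
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)

# Both forms pass a zero-argument callable; the lambda simply binds the
# arguments without calling input_fn (and building the graph) too early.
estimator.train(input_fn=my_input_fn, steps=3)
# estimator.train(input_fn=lambda: input_fn(data_file, num_epochs, True, batch_size))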
@@ -54,7 +54,7 @@ class BaseTest(tf.test.TestCase):
       temp_csv.write(TEST_INPUT)
 
   def test_input_fn(self):
-    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)()
+    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)
     with tf.Session() as sess:
       features, labels = sess.run((features, labels))
@@ -78,20 +78,20 @@ class BaseTest(tf.test.TestCase):
     # Train for 1 step to initialize model and evaluate initial loss
     model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=True, batch_size=1),
         steps=1)
     initial_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
     # Train for 40 steps at batch size 2 and evaluate final loss
     model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=None, shuffle=True, batch_size=2),
         steps=40)
     final_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
     print('%s initial results:' % model_type, initial_results)
...