Unverified commit 4cfa0d3b authored by Neal Wu, committed by GitHub

Input improvements (#2706)

parent dcc23689
@@ -79,13 +79,16 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
     # a small dataset, we can easily shuffle the full epoch.
     dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
 
   # Map example_parser over dataset, and batch results by up to batch_size
   dataset = dataset.map(
       example_parser, num_threads=1, output_buffer_size=batch_size)
   dataset = dataset.batch(batch_size)
-  images, labels = dataset.make_one_shot_iterator().get_next()
+  iterator = dataset.make_one_shot_iterator()
+  images, labels = iterator.get_next()
 
   return images, labels
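The hunk above settles on one canonical pipeline order: shuffle, then repeat, then map/batch, then a one-shot iterator. A minimal self-contained sketch of that order, assuming the same TF 1.3-era tf.contrib.data API the diff uses; the in-memory stand-in data and the toy map function are hypothetical, for illustration only:

import tensorflow as tf

def toy_input_fn(batch_size=4, num_epochs=2):
  # Hypothetical in-memory data standing in for parsed TFRecords.
  images = tf.range(100)       # stand-in "images"
  labels = tf.range(100) % 10  # stand-in "labels"
  dataset = tf.contrib.data.Dataset.from_tensor_slices((images, labels))

  # Same order as the commit: shuffle first, then repeat, so epoch
  # boundaries are preserved; then map/batch and a one-shot iterator.
  dataset = dataset.shuffle(buffer_size=100)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.map(lambda x, y: (tf.cast(x, tf.float32), y),
                        num_threads=1, output_buffer_size=batch_size)
  dataset = dataset.batch(batch_size)
  iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()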
...
@@ -166,11 +166,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       num_threads=1,
       output_buffer_size=2 * batch_size)
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
 
   # Batch results by up to batch_size, and then fetch the tuple from the
   # iterator.
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  dataset = dataset.batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
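The comment repeated in each hunk is the heart of the commit: if repeat runs before shuffle, the shuffle buffer can hold the tail of one epoch and the head of the next, so a batch may mix examples from different epochs; shuffling first keeps each epoch a true permutation. A toy sketch (hypothetical values, same contrib-era API) that makes the difference observable:

import tensorflow as tf

data = list(range(5))

# Blended: repeat before shuffle lets elements from epoch 2 appear
# before epoch 1 has finished (the buffer holds both copies at once).
blended = (tf.contrib.data.Dataset.from_tensor_slices(data)
           .repeat(2)
           .shuffle(buffer_size=10))

# Clean: shuffle before repeat permutes within each epoch only.
clean = (tf.contrib.data.Dataset.from_tensor_slices(data)
         .shuffle(buffer_size=10)
         .repeat(2))

with tf.Session() as sess:
  nxt = blended.make_one_shot_iterator().get_next()
  print([sess.run(nxt) for _ in range(10)])  # may even start 3, 3, ...
  nxt = clean.make_one_shot_iterator().get_next()
  print([sess.run(nxt) for _ in range(10)])  # first 5 permute 0..4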
...
@@ -73,6 +73,7 @@ _NUM_IMAGES = {
     'validation': 50000,
 }
 
+_FILE_SHUFFLE_BUFFER = 1024
 _SHUFFLE_BUFFER = 1500
@@ -81,11 +82,11 @@ def filenames(is_training, data_dir):
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(0, 1024)]
+        for i in range(1024)]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
-        for i in range(0, 128)]
+        for i in range(128)]
 
 
 def dataset_parser(value, is_training):
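For reference, the %05d pattern zero-pads the shard index, so the generated shard names look like this (a quick illustration, not part of the commit):

print('train-%05d-of-01024' % 3)         # train-00003-of-01024
print('validation-%05d-of-00128' % 127)  # validation-00127-of-00128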
@@ -137,11 +138,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       filenames(is_training, data_dir))
 
   if is_training:
-    dataset = dataset.shuffle(buffer_size=1024)
-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  if is_training:
-    dataset = dataset.repeat(num_epochs)
+  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
 
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
                         num_threads=5,
@@ -152,7 +151,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
+  dataset = dataset.repeat(num_epochs)
+
+  dataset = dataset.batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
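Taken together, the two hunks above give the large-dataset recipe: shuffle the small list of shard filenames with one buffer, flat_map each file into its stream of records, shuffle individual records with a second buffer, then repeat and batch. A condensed sketch under the same contrib-era API; the filename pattern and buffer sizes are copied from the diff, while parse_record is a hypothetical stand-in for dataset_parser:

import os
import tensorflow as tf

_FILE_SHUFFLE_BUFFER = 1024  # shuffles the order of shard files
_SHUFFLE_BUFFER = 1500       # shuffles individual records

def sharded_input_fn(data_dir, parse_record, batch_size, num_epochs=1):
  files = [os.path.join(data_dir, 'train-%05d-of-01024' % i)
           for i in range(1024)]
  dataset = tf.contrib.data.Dataset.from_tensor_slices(files)
  dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

  # Expand each filename into the records that file contains.
  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)

  dataset = dataset.map(parse_record, num_threads=5,
                        output_buffer_size=batch_size)
  dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset.make_one_shot_iterator().get_next()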
@@ -188,8 +192,8 @@ def resnet_model_fn(features, labels, mode, params):
       [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
 
   if mode == tf.estimator.ModeKeys.TRAIN:
-    # Scale the learning rate linearly with the batch size. When the batch size is
-    # 256, the learning rate should be 0.1.
+    # Scale the learning rate linearly with the batch size. When the batch size
+    # is 256, the learning rate should be 0.1.
     initial_learning_rate = 0.1 * params['batch_size'] / 256
     batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
     global_step = tf.train.get_or_create_global_step()
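The rewrapped comment encodes a concrete rule: the learning rate scales linearly with batch size, anchored at 0.1 for batch size 256. Worked through for a few example batch sizes (plain arithmetic; the sizes are chosen for illustration):

for batch_size in (64, 128, 256, 512):
  print(batch_size, 0.1 * batch_size / 256)
# 64 -> 0.025, 128 -> 0.05, 256 -> 0.1, 512 -> 0.2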
...
@@ -61,6 +61,8 @@ parser.add_argument(
     '--test_data', type=str, default='/tmp/census_data/adult.test',
     help='Path to the test data.')
 
+_SHUFFLE_BUFFER = 100000
+
 def build_model_columns():
   """Builds a set of wide and deep feature columns."""
@@ -167,6 +169,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   assert tf.gfile.Exists(data_file), (
       '%s not found. Please make sure you have either run data_download.py or '
       'set both arguments --train_data and --test_data.' % data_file)
+
   def parse_csv(value):
     print('Parsing', data_file)
     columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
@@ -178,49 +181,36 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   dataset = tf.contrib.data.TextLineDataset(data_file)
   dataset = dataset.map(parse_csv, num_threads=5)
 
-  # Apply transformations to the Dataset
-  dataset = dataset.batch(batch_size)
+  if shuffle:
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
+  dataset = dataset.batch(batch_size)
 
-  # Input function that is called by the Estimator
-  def _input_fn():
-    if shuffle:
-      # Apply shuffle transformation to re-shuffle the dataset in each call.
-      shuffled_dataset = dataset.shuffle(buffer_size=100000)
-      iterator = shuffled_dataset.make_one_shot_iterator()
-    else:
-      iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    return features, labels
-
-  return _input_fn
+  iterator = dataset.make_one_shot_iterator()
+  features, labels = iterator.get_next()
+  return features, labels
 
 
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
   model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
 
-  # Set up input function generators for the train and test data files.
-  train_input_fn = input_fn(
-      data_file=FLAGS.train_data,
-      num_epochs=FLAGS.epochs_per_eval,
-      shuffle=True,
-      batch_size=FLAGS.batch_size)
-  eval_input_fn = input_fn(
-      data_file=FLAGS.test_data,
-      num_epochs=1,
-      shuffle=False,
-      batch_size=FLAGS.batch_size)
-
   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
-    model.train(input_fn=train_input_fn)
-    results = model.evaluate(input_fn=eval_input_fn)
+    model.train(input_fn=lambda: input_fn(
+        FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
+
+    results = model.evaluate(input_fn=lambda: input_fn(
+        FLAGS.test_data, 1, False, FLAGS.batch_size))
 
     # Display evaluation metrics
     print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
     print('-' * 30)
 
     for key in sorted(results):
      print('%s: %s' % (key, results[key]))
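The structural change in this file: input_fn used to return a closure (_input_fn), so callers passed input_fn(...) directly; now it builds and returns the (features, labels) tensors itself, and callers defer it with a lambda, because Estimator.train and Estimator.evaluate expect a zero-argument callable that they invoke inside their own graph. A minimal runnable sketch of the pattern; the toy data and the LinearClassifier are illustrative stand-ins, not from the commit:

import tensorflow as tf

def input_fn(num_epochs, shuffle, batch_size):
  # Toy stand-in for the CSV pipeline in the diff.
  xs = [[0.0], [1.0], [2.0], [3.0]]
  ys = [0, 0, 1, 1]
  dataset = tf.contrib.data.Dataset.from_tensor_slices(({'x': xs}, ys))
  if shuffle:
    dataset = dataset.shuffle(buffer_size=4)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  features, labels = dataset.make_one_shot_iterator().get_next()
  return features, labels

model = tf.estimator.LinearClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')])

# The Estimator calls input_fn itself, once per train/evaluate call and
# inside its own graph, so pass a zero-argument callable rather than
# calling input_fn eagerly here (the old code's closure worked around
# exactly this).
model.train(input_fn=lambda: input_fn(num_epochs=2, shuffle=True,
                                      batch_size=2))
print(model.evaluate(input_fn=lambda: input_fn(1, False, 2)))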
...
@@ -54,7 +54,7 @@ class BaseTest(tf.test.TestCase):
       temp_csv.write(TEST_INPUT)
 
   def test_input_fn(self):
-    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)()
+    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)
     with tf.Session() as sess:
       features, labels = sess.run((features, labels))
@@ -78,20 +78,20 @@ class BaseTest(tf.test.TestCase):
     # Train for 1 step to initialize model and evaluate initial loss
     model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=True, batch_size=1),
         steps=1)
     initial_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
     # Train for 40 steps at batch size 2 and evaluate final loss
     model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=None, shuffle=True, batch_size=2),
         steps=40)
     final_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
             TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
     print('%s initial results:' % model_type, initial_results)