Unverified Commit 34e79348 authored by Taylor Robie, committed by GitHub

Arg parsing cleanup for MNIST and Wide-Deep (#3684)

* move wide_deep parser

* move mnist parsers
parent 1d38a225
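
In essence, the commit swaps module-level globals populated by parse_known_args (and dispatched through tf.app.run) for a main(argv) that builds its parser and parses argv[1:] locally. A minimal sketch of the two patterns, assuming an invented DemoArgParser with a single illustrative --batch_size flag (neither name is from the diff):

import argparse
import sys


class DemoArgParser(argparse.ArgumentParser):
  """Illustrative stand-in for MNISTArgParser/WideDeepArgParser."""

  def __init__(self):
    super(DemoArgParser, self).__init__()
    self.add_argument('--batch_size', type=int, default=100)


# Before: flags lived in a module-level global, populated at startup.
#   FLAGS, unparsed = DemoArgParser().parse_known_args()
#   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

# After: main owns its parser; no global state, and callable in tests.
def main(argv):
  flags = DemoArgParser().parse_args(args=argv[1:])
  print('batch size: %d' % flags.batch_size)


if __name__ == '__main__':
  main(argv=sys.argv)
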
--- a/official/mnist/mnist.py
+++ b/official/mnist/mnist.py
@@ -175,11 +175,14 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)
 
 
-def main(_):
+def main(argv):
+  parser = MNISTArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   model_function = model_fn
 
-  if FLAGS.multi_gpu:
-    validate_batch_size_for_multi_gpu(FLAGS.batch_size)
+  if flags.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags.batch_size)
 
     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -187,16 +190,16 @@ def main(_):
     model_function = tf.contrib.estimator.replicate_model_fn(
         model_fn, loss_reduction=tf.losses.Reduction.MEAN)
 
-  data_format = FLAGS.data_format
+  data_format = flags.data_format
   if data_format is None:
     data_format = ('channels_first'
                    if tf.test.is_built_with_cuda() else 'channels_last')
 
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
-      model_dir=FLAGS.model_dir,
+      model_dir=flags.model_dir,
       params={
           'data_format': data_format,
-          'multi_gpu': FLAGS.multi_gpu
+          'multi_gpu': flags.multi_gpu
       })
@@ -206,35 +209,35 @@ def main(_):
   # Set up training and evaluation input functions.
   def train_input_fn():
     """Prepare data for training."""
 
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
-    ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
+    ds = dataset.train(flags.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
 
     # Iterate through the dataset a set number (`epochs_between_evals`) of times
     # during each training session.
-    ds = ds.repeat(FLAGS.epochs_between_evals)
+    ds = ds.repeat(flags.epochs_between_evals)
     return ds
 
   def eval_input_fn():
-    return dataset.test(FLAGS.data_dir).batch(
-        FLAGS.batch_size).make_one_shot_iterator().get_next()
+    return dataset.test(flags.data_dir).batch(
+        flags.batch_size).make_one_shot_iterator().get_next()
 
   # Set up hook that outputs training logs every 100 steps.
   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size)
+      flags.hooks, batch_size=flags.batch_size)
 
   # Train and evaluate model.
-  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for _ in range(flags.train_epochs // flags.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)
 
   # Export the model
-  if FLAGS.export_dir is not None:
+  if flags.export_dir is not None:
     image = tf.placeholder(tf.float32, [None, 28, 28])
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
         'image': image,
     })
-    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
+    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)
@@ -261,6 +264,4 @@ class MNISTArgParser(argparse.ArgumentParser):
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = MNISTArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
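
With the parser moved inside main, the entry point can now be driven programmatically with a synthetic argv, which is handy for smoke tests. A hedged example, assuming `from official.mnist import mnist` and that the flag spellings match the parser's definitions (values are illustrative):

# argv[0] is a placeholder program name; parse_args only sees argv[1:].
mnist.main(argv=['mnist.py', '--batch_size', '100', '--train_epochs', '1'])
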
--- a/official/mnist/mnist_eager.py
+++ b/official/mnist/mnist_eager.py
@@ -38,8 +38,6 @@ from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
 from official.utils.arg_parsers import parsers
 
-FLAGS = None
-
 
 def loss(logits, labels):
   return tf.reduce_mean(
@@ -97,35 +95,38 @@ def test(model, dataset):
   tf.contrib.summary.scalar('accuracy', accuracy.result())
 
 
-def main(_):
+def main(argv):
+  parser = MNISTEagerArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   tfe.enable_eager_execution()
 
   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+  if flags.no_gpu or tfe.num_gpus() <= 0:
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if FLAGS.data_format is not None:
-    data_format = FLAGS.data_format
+  if flags.data_format is not None:
+    data_format = flags.data_format
   print('Using device %s, and data format %s.' % (device, data_format))
 
   # Load the datasets
-  train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
-      FLAGS.batch_size)
-  test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
+      flags.batch_size)
+  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
 
   # Create the model and optimizer
   model = mnist.Model(data_format)
-  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
 
   # Create file writers for writing TensorBoard summaries.
-  if FLAGS.output_dir:
+  if flags.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(FLAGS.output_dir, 'train')
-    test_dir = os.path.join(FLAGS.output_dir, 'eval')
-    tf.gfile.MakeDirs(FLAGS.output_dir)
+    train_dir = os.path.join(flags.output_dir, 'train')
+    test_dir = os.path.join(flags.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags.output_dir)
   else:
     train_dir = None
     test_dir = None
@@ -135,19 +136,19 @@ def main(_):
       test_dir, flush_millis=10000, name='test')
 
   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
 
   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(FLAGS.train_epochs):
+    for _ in range(flags.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
+        train(model, optimizer, train_ds, step_counter, flags.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,
@@ -205,6 +206,4 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
     )
 
 if __name__ == '__main__':
-  parser = MNISTEagerArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
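
MNISTEagerArgParser itself sits mostly outside the hunks shown, but its shape is visible from how it is used: an argparse.ArgumentParser subclass that registers flags in __init__. A rough sketch under that assumption; the flag names are inferred from the flags.* reads in the diff above, while every default value is invented for illustration:

import argparse


class MNISTEagerArgParser(argparse.ArgumentParser):
  """Sketch only; the real defaults and help strings live in mnist_eager.py."""

  def __init__(self):
    super(MNISTEagerArgParser, self).__init__()
    # All defaults below are placeholders, not the repository's values.
    self.add_argument('--data_dir', type=str, default='/tmp/mnist_data')
    self.add_argument('--model_dir', type=str, default='/tmp/mnist_model')
    self.add_argument('--output_dir', type=str, default=None)
    self.add_argument('--batch_size', type=int, default=100)
    self.add_argument('--train_epochs', type=int, default=10)
    self.add_argument('--log_interval', type=int, default=10)
    self.add_argument('--lr', type=float, default=0.01)
    self.add_argument('--momentum', type=float, default=0.5)
    self.add_argument('--data_format', type=str, default=None)
    self.add_argument('--no_gpu', action='store_true', default=False)
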
--- a/official/wide_deep/wide_deep.py
+++ b/official/wide_deep/wide_deep.py
@@ -171,33 +171,36 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   return dataset
 
 
-def main(_):
+def main(argv):
+  parser = WideDeepArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   # Clean up the model directory if present
-  shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
-  model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
+  shutil.rmtree(flags.model_dir, ignore_errors=True)
+  model = build_estimator(flags.model_dir, flags.model_type)
 
-  train_file = os.path.join(FLAGS.data_dir, 'adult.data')
-  test_file = os.path.join(FLAGS.data_dir, 'adult.test')
+  train_file = os.path.join(flags.data_dir, 'adult.data')
+  test_file = os.path.join(flags.data_dir, 'adult.test')
 
   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   def train_input_fn():
-    return input_fn(train_file, FLAGS.epochs_per_eval, True, FLAGS.batch_size)
+    return input_fn(train_file, flags.epochs_per_eval, True, flags.batch_size)
 
   def eval_input_fn():
-    return input_fn(test_file, 1, False, FLAGS.batch_size)
+    return input_fn(test_file, 1, False, flags.batch_size)
 
   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size,
+      flags.hooks, batch_size=flags.batch_size,
       tensors_to_log={'average_loss': 'head/truediv',
                       'loss': 'head/weighted_loss/Sum'})
 
   # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for n in range(flags.train_epochs // flags.epochs_between_evals):
     model.train(input_fn=train_input_fn, hooks=train_hooks)
     results = model.evaluate(input_fn=eval_input_fn)
 
     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
+    print('Results at epoch', (n + 1) * flags.epochs_between_evals)
     print('-' * 60)
 
     for key in sorted(results):
@@ -224,6 +227,4 @@ class WideDeepArgParser(argparse.ArgumentParser):
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = WideDeepArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
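
One behavioral consequence of the change is worth noting: parse_known_args silently collects unrecognized flags (the old code forwarded them to tf.app.run), whereas parse_args on argv[1:] exits with an error on anything unknown. A small self-contained illustration; the parser and its --batch_size flag are invented for the demo:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=100)

# Old style: unknown flags are tolerated and returned separately.
flags, unparsed = parser.parse_known_args(['--batch_size', '64', '--bogus'])
print(unparsed)  # ['--bogus']

# New style: an unknown flag is now a hard error (SystemExit).
try:
  parser.parse_args(['--batch_size', '64', '--bogus'])
except SystemExit:
  print('unrecognized argument rejected')
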