Unverified commit 34e79348, authored by Taylor Robie, committed by GitHub

Arg parsing cleanup for MNIST and Wide-Deep (#3684)

* move wide_deep parser

* move mnist parsers
parent 1d38a225
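
The change is the same across all three scripts: drop the module-level FLAGS global that parse_known_args() used to populate at import/startup time, and let main(argv) build its own parser and consume argv[1:] directly. As a minimal sketch of the before and after (the toy parser and single flag below stand in for the repository's MNISTArgParser, MNISTEagerArgParser, and WideDeepArgParser classes):

import argparse
import sys

# --- Old pattern: a module-level global, filled in at script startup ---
FLAGS = None

def main_old(_):
  print(FLAGS.train_epochs)  # reads mutable module state
# FLAGS, unparsed = parser.parse_known_args()
# tf.app.run(main=main_old, argv=[sys.argv[0]] + unparsed)

# --- New pattern: main() owns its parser; no module state ---
def main(argv):
  parser = argparse.ArgumentParser()  # stand-in for the per-model parser classes
  parser.add_argument('--train_epochs', type=int, default=40)
  flags = parser.parse_args(args=argv[1:])
  print(flags.train_epochs)

if __name__ == '__main__':
  main(argv=sys.argv)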
official/mnist/mnist.py:

@@ -175,11 +175,14 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)


-def main(_):
+def main(argv):
+  parser = MNISTArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   model_function = model_fn

-  if FLAGS.multi_gpu:
-    validate_batch_size_for_multi_gpu(FLAGS.batch_size)
+  if flags.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags.batch_size)

     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -187,16 +190,16 @@ def main(_):
     model_function = tf.contrib.estimator.replicate_model_fn(
         model_fn, loss_reduction=tf.losses.Reduction.MEAN)

-  data_format = FLAGS.data_format
+  data_format = flags.data_format
   if data_format is None:
     data_format = ('channels_first'
                    if tf.test.is_built_with_cuda() else 'channels_last')

   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
-      model_dir=FLAGS.model_dir,
+      model_dir=flags.model_dir,
       params={
           'data_format': data_format,
-          'multi_gpu': FLAGS.multi_gpu
+          'multi_gpu': flags.multi_gpu
       })

   # Set up training and evaluation input functions.
@@ -206,35 +209,35 @@ def main(_):
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
-    ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
+    ds = dataset.train(flags.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)

     # Iterate through the dataset a set number (`epochs_between_evals`) of times
     # during each training session.
-    ds = ds.repeat(FLAGS.epochs_between_evals)
+    ds = ds.repeat(flags.epochs_between_evals)
     return ds

   def eval_input_fn():
-    return dataset.test(FLAGS.data_dir).batch(
-        FLAGS.batch_size).make_one_shot_iterator().get_next()
+    return dataset.test(flags.data_dir).batch(
+        flags.batch_size).make_one_shot_iterator().get_next()

   # Set up hook that outputs training logs every 100 steps.
   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size)
+      flags.hooks, batch_size=flags.batch_size)

   # Train and evaluate model.
-  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for _ in range(flags.train_epochs // flags.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)

   # Export the model
-  if FLAGS.export_dir is not None:
+  if flags.export_dir is not None:
     image = tf.placeholder(tf.float32, [None, 28, 28])
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
         'image': image,
     })
-    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
+    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)


 class MNISTArgParser(argparse.ArgumentParser):
@@ -261,6 +264,4 @@ class MNISTArgParser(argparse.ArgumentParser):

 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = MNISTArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
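
The parser's body is elided by the last hunk above (it shows only the tail of the class). As a rough sketch, an argparse.ArgumentParser subclass serving this main() might look like the following; the flag set is taken from what main() reads above, but the defaults and help strings are illustrative placeholders, not the repository's:

import argparse

class MNISTArgParser(argparse.ArgumentParser):
  """Rough stand-in for the parser whose body the hunk elides."""

  def __init__(self):
    super(MNISTArgParser, self).__init__()
    # Only the flags that main() reads above; defaults are placeholders.
    self.add_argument('--multi_gpu', action='store_true',
                      help='If set, run across all available GPUs.')
    self.add_argument('--batch_size', type=int, default=100)
    self.add_argument('--data_dir', type=str, default='/tmp/mnist_data')
    self.add_argument('--model_dir', type=str, default='/tmp/mnist_model')
    self.add_argument('--train_epochs', type=int, default=40)
    self.add_argument('--epochs_between_evals', type=int, default=1)
    self.add_argument('--data_format', type=str, default=None,
                      choices=['channels_first', 'channels_last'])
    self.add_argument('--hooks', nargs='+', default=['LoggingTensorHook'])
    self.add_argument('--export_dir', type=str, default=None)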
official/mnist/mnist_eager.py:

@@ -38,8 +38,6 @@ from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
 from official.utils.arg_parsers import parsers

-FLAGS = None
-

 def loss(logits, labels):
   return tf.reduce_mean(
@@ -97,35 +95,38 @@ def test(model, dataset):
     tf.contrib.summary.scalar('accuracy', accuracy.result())


-def main(_):
+def main(argv):
+  parser = MNISTEagerArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   tfe.enable_eager_execution()

   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+  if flags.no_gpu or tfe.num_gpus() <= 0:
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if FLAGS.data_format is not None:
-    data_format = FLAGS.data_format
+  if flags.data_format is not None:
+    data_format = flags.data_format
   print('Using device %s, and data format %s.' % (device, data_format))

   # Load the datasets
-  train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
-      FLAGS.batch_size)
-  test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
+      flags.batch_size)
+  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)

   # Create the model and optimizer
   model = mnist.Model(data_format)
-  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)

   # Create file writers for writing TensorBoard summaries.
-  if FLAGS.output_dir:
+  if flags.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(FLAGS.output_dir, 'train')
-    test_dir = os.path.join(FLAGS.output_dir, 'eval')
-    tf.gfile.MakeDirs(FLAGS.output_dir)
+    train_dir = os.path.join(flags.output_dir, 'train')
+    test_dir = os.path.join(flags.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags.output_dir)
   else:
     train_dir = None
     test_dir = None
@@ -135,19 +136,19 @@ def main(_):
       test_dir, flush_millis=10000, name='test')

   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))

   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(FLAGS.train_epochs):
+    for _ in range(flags.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
+        train(model, optimizer, train_ds, step_counter, flags.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,
@@ -205,6 +206,4 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
     )

 if __name__ == '__main__':
-  parser = MNISTEagerArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
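
One behavioral consequence of this cleanup is worth noting: parse_known_args() silently tolerated unrecognized flags and handed them back (to be forwarded to tf.app.run), whereas parse_args() treats any unknown flag as a hard error. A small self-contained demonstration of the difference:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=100)

# Old style: unknown flags are collected rather than rejected.
flags, unparsed = parser.parse_known_args(['--batch_size', '64', '--bogus'])
print(flags.batch_size)  # 64
print(unparsed)          # ['--bogus']

# New style: an unknown flag makes argparse print usage and raise SystemExit.
try:
  parser.parse_args(['--batch_size', '64', '--bogus'])
except SystemExit:
  print('unrecognized arguments: --bogus')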
official/wide_deep/wide_deep.py:

@@ -171,33 +171,36 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   return dataset


-def main(_):
+def main(argv):
+  parser = WideDeepArgParser()
+  flags = parser.parse_args(args=argv[1:])
+
   # Clean up the model directory if present
-  shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
-  model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
+  shutil.rmtree(flags.model_dir, ignore_errors=True)
+  model = build_estimator(flags.model_dir, flags.model_type)

-  train_file = os.path.join(FLAGS.data_dir, 'adult.data')
-  test_file = os.path.join(FLAGS.data_dir, 'adult.test')
+  train_file = os.path.join(flags.data_dir, 'adult.data')
+  test_file = os.path.join(flags.data_dir, 'adult.test')

   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   def train_input_fn():
-    return input_fn(train_file, FLAGS.epochs_per_eval, True, FLAGS.batch_size)
+    return input_fn(train_file, flags.epochs_per_eval, True, flags.batch_size)

   def eval_input_fn():
-    return input_fn(test_file, 1, False, FLAGS.batch_size)
+    return input_fn(test_file, 1, False, flags.batch_size)

   train_hooks = hooks_helper.get_train_hooks(
-      FLAGS.hooks, batch_size=FLAGS.batch_size,
+      flags.hooks, batch_size=flags.batch_size,
       tensors_to_log={'average_loss': 'head/truediv',
                       'loss': 'head/weighted_loss/Sum'})

   # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for n in range(flags.train_epochs // flags.epochs_between_evals):
     model.train(input_fn=train_input_fn, hooks=train_hooks)
     results = model.evaluate(input_fn=eval_input_fn)

     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
+    print('Results at epoch', (n + 1) * flags.epochs_between_evals)
     print('-' * 60)

     for key in sorted(results):
@@ -224,6 +227,4 @@ class WideDeepArgParser(argparse.ArgumentParser):

 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  parser = WideDeepArgParser()
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  main(argv=sys.argv)
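
A side benefit of main(argv) parsing its own arguments: a test or another script can now drive it with a synthetic argv instead of mutating a module-level FLAGS global before calling tf.app.run. Hypothetical usage, assuming the repository layout suggested by the imports above; the flag values are illustrative:

from official.wide_deep import wide_deep  # import path assumed, not shown in the diff

# argv[0] is skipped by parse_args(args=argv[1:]), so any placeholder works.
wide_deep.main(argv=['wide_deep.py',
                     '--model_type', 'wide_deep',
                     '--train_epochs', '2',
                     '--batch_size', '40'])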