Commit a367dc49 authored by Asim Shankar's avatar Asim Shankar
Browse files

[mnist]: Make flags function scoped (not module scoped).

This will make it easier to share the model definition with
eager execution and TPU demos without any side effects of running
unnecessary code on module import.
parent 72f5834c
...@@ -24,45 +24,6 @@ import sys ...@@ -24,45 +24,6 @@ import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
parser = argparse.ArgumentParser()
# Basic model parameters.
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='Number of images to process in a batch')
parser.add_argument(
'--data_dir',
type=str,
default='/tmp/mnist_data',
help='Path to directory containing the MNIST dataset')
parser.add_argument(
'--model_dir',
type=str,
default='/tmp/mnist_model',
help='The directory where the model will be stored.')
parser.add_argument(
'--train_epochs', type=int, default=40, help='Number of epochs to train.')
parser.add_argument(
'--data_format',
type=str,
default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
parser.add_argument(
'--export_dir',
type=str,
help='The directory where the exported SavedModel will be stored.')
def train_dataset(data_dir): def train_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training.""" """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
...@@ -220,6 +181,38 @@ def main(unused_argv): ...@@ -220,6 +181,38 @@ def main(unused_argv):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='Number of images to process in a batch')
parser.add_argument(
'--data_dir',
type=str,
default='/tmp/mnist_data',
help='Path to directory containing the MNIST dataset')
parser.add_argument(
'--model_dir',
type=str,
default='/tmp/mnist_model',
help='The directory where the model will be stored.')
parser.add_argument(
'--train_epochs', type=int, default=40, help='Number of epochs to train.')
parser.add_argument(
'--data_format',
type=str,
default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. channels_first '
'provides a performance boost on GPU but is not always compatible '
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
parser.add_argument(
'--export_dir',
type=str,
help='The directory where the exported SavedModel will be stored.')
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment