Commit 0e09477a authored by Olivia's avatar Olivia Committed by GitHub
Browse files

Merge pull request #2443 from tensorflow/flags

Flags
parents 1630da34 8815a860
...@@ -35,6 +35,7 @@ from __future__ import absolute_import ...@@ -35,6 +35,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import re import re
import sys import sys
...@@ -45,15 +46,19 @@ import tensorflow as tf ...@@ -45,15 +46,19 @@ import tensorflow as tf
import cifar10_input import cifar10_input
FLAGS = tf.app.flags.FLAGS parser = argparse.ArgumentParser()
# Basic model parameters. # Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128, parser.add_argument('--batch_size', type=int, default=128,
"""Number of images to process in a batch.""") help='Number of images to process in a batch.')
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""") parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data',
tf.app.flags.DEFINE_boolean('use_fp16', False, help='Path to the CIFAR-10 data directory.')
"""Train the model using fp16.""")
parser.add_argument('--use_fp16', type=bool, default=False,
help='Train the model using fp16.')
FLAGS = parser.parse_args()
# Global constants describing the CIFAR-10 data set. # Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE IMAGE_SIZE = cifar10_input.IMAGE_SIZE
......
...@@ -34,6 +34,7 @@ from __future__ import absolute_import ...@@ -34,6 +34,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
from datetime import datetime from datetime import datetime
import math import math
import time import time
...@@ -43,20 +44,25 @@ import tensorflow as tf ...@@ -43,20 +44,25 @@ import tensorflow as tf
import cifar10 import cifar10
FLAGS = tf.app.flags.FLAGS parser = cifar10.parser
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval', parser.add_argument('--eval_dir', type=str, default='/tmp/cifar10_eval',
"""Directory where to write event logs.""") help='Directory where to write event logs.')
tf.app.flags.DEFINE_string('eval_data', 'test',
"""Either 'test' or 'train_eval'.""") parser.add_argument('--eval_data', type=str, default='test',
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train', help='Either `test` or `train_eval`.')
"""Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, parser.add_argument('--checkpoint_dir', type=str, default='/tmp/cifar10_train',
"""How often to run the eval.""") help='Directory where to read model checkpoints.')
tf.app.flags.DEFINE_integer('num_examples', 10000,
"""Number of examples to run.""") parser.add_argument('--eval_interval_secs', type=int, default=60*5,
tf.app.flags.DEFINE_boolean('run_once', False, help='How often to run the eval.')
"""Whether to run eval only once.""")
parser.add_argument('--num_examples', type=int, default=10000,
help='Number of examples to run.')
parser.add_argument('--run_once', type=bool, default=False,
help='Whether to run eval only once.')
def eval_once(saver, summary_writer, top_k_op, summary_op): def eval_once(saver, summary_writer, top_k_op, summary_op):
...@@ -154,4 +160,5 @@ def main(argv=None): # pylint: disable=unused-argument ...@@ -154,4 +160,5 @@ def main(argv=None): # pylint: disable=unused-argument
if __name__ == '__main__': if __name__ == '__main__':
FLAGS = parser.parse_args()
tf.app.run() tf.app.run()
...@@ -39,6 +39,7 @@ from __future__ import absolute_import ...@@ -39,6 +39,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
from datetime import datetime from datetime import datetime
import os.path import os.path
import re import re
...@@ -49,17 +50,19 @@ from six.moves import xrange # pylint: disable=redefined-builtin ...@@ -49,17 +50,19 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
import cifar10 import cifar10
FLAGS = tf.app.flags.FLAGS parser = cifar10.parser
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
"""Directory where to write event logs """ help='Directory where to write event logs and checkpoint.')
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000, parser.add_argument('--max_steps', type=int, default=1000000,
"""Number of batches to run.""") help='Number of batches to run.')
tf.app.flags.DEFINE_integer('num_gpus', 1,
"""How many GPUs to use.""") parser.add_argument('--num_gpus', type=int, default=1,
tf.app.flags.DEFINE_boolean('log_device_placement', False, help='How many GPUs to use.')
"""Whether to log device placement.""")
parser.add_argument('--log_device_placement', type=bool, default=False,
help='Whether to log device placement.')
def tower_loss(scope, images, labels): def tower_loss(scope, images, labels):
...@@ -274,4 +277,5 @@ def main(argv=None): # pylint: disable=unused-argument ...@@ -274,4 +277,5 @@ def main(argv=None): # pylint: disable=unused-argument
if __name__ == '__main__': if __name__ == '__main__':
FLAGS = parser.parse_args()
tf.app.run() tf.app.run()
...@@ -36,6 +36,7 @@ from __future__ import absolute_import ...@@ -36,6 +36,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
from datetime import datetime from datetime import datetime
import time import time
...@@ -43,17 +44,19 @@ import tensorflow as tf ...@@ -43,17 +44,19 @@ import tensorflow as tf
import cifar10 import cifar10
FLAGS = tf.app.flags.FLAGS parser = cifar10.parser
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
"""Directory where to write event logs """ help='Directory where to write event logs and checkpoint.')
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000, parser.add_argument('--max_steps', type=int, default=1000000,
"""Number of batches to run.""") help='Number of batches to run.')
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""") parser.add_argument('--log_device_placement', type=bool, default=False,
tf.app.flags.DEFINE_integer('log_frequency', 10, help='Whether to log device placement.')
"""How often to log results to the console.""")
parser.add_argument('--log_frequency', type=int, default=10,
help='How often to log results to the console.')
def train(): def train():
...@@ -124,4 +127,5 @@ def main(argv=None): # pylint: disable=unused-argument ...@@ -124,4 +127,5 @@ def main(argv=None): # pylint: disable=unused-argument
if __name__ == '__main__': if __name__ == '__main__':
FLAGS = parser.parse_args()
tf.app.run() tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment