Commit d86aa760 authored by Neal Wu, committed by GitHub

Modify cifar10_main.py and imagenet_main.py to parse flags in main only (#2582)

parent 52a87b3a
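The diff below applies one pattern to both training scripts: the argparse parser stays at module level, but the flags are now parsed only inside the `__main__` block, and anything argparse does not recognize is handed back to `tf.app.run`. A minimal standalone sketch of that pattern (simplified, not the exact file contents):

from __future__ import print_function

import argparse
import sys

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=128,
                    help='The number of images per batch.')

# Populated in __main__ rather than at import time, so importing this module
# (for example from a test) no longer consumes sys.argv.
FLAGS = None


def main(unused_argv):
  # tf.app.run only invokes main() after FLAGS has been set below.
  print('batch_size =', FLAGS.batch_size)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(argv=[sys.argv[0]] + unparsed)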
@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
@@ -46,29 +47,21 @@ parser.add_argument('--epochs_per_eval', type=int, default=10,
parser.add_argument('--batch_size', type=int, default=128,
help='The number of images per batch.')
FLAGS = parser.parse_args()
_HEIGHT = 32
_WIDTH = 32
_DEPTH = 3
_NUM_CLASSES = 10
_NUM_DATA_FILES = 5
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
# Scale the learning rate linearly with the batch size. When the batch size is
# 128, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 128
_MOMENTUM = 0.9
# We use a weight decay of 0.0002, which performs better than the 0.0001 that
# was originally suggested.
_WEIGHT_DECAY = 2e-4
_MOMENTUM = 0.9
_BATCHES_PER_EPOCH = _NUM_IMAGES['train'] / FLAGS.batch_size
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
def record_dataset(filenames):
@@ -205,11 +198,15 @@ def cifar10_model_fn(features, labels, mode):
[tf.nn.l2_loss(v) for v in tf.trainable_variables()])
if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size
# is 128, the learning rate should be 0.1.
initial_learning_rate = 0.1 * FLAGS.batch_size / 128
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
global_step = tf.train.get_or_create_global_step()
# Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
learning_rate = tf.train.piecewise_constant(
tf.cast(global_step, tf.int32), boundaries, values)
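With the defaults shown in this diff (50,000 CIFAR-10 training images, --batch_size=128), the in-function schedule above works out as in the small worked sketch below, assuming only those defaults:

# Worked example of the CIFAR-10 schedule above, assuming the defaults
# shown in this diff (50,000 training images, batch_size=128).
from __future__ import division

batch_size = 128
num_train_images = 50000  # _NUM_IMAGES['train']

initial_learning_rate = 0.1 * batch_size / 128      # 0.1
batches_per_epoch = num_train_images / batch_size   # 390.625

# Drop the learning rate by 10x at epochs 100, 150 and 200.
boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]

print(boundaries)  # [39062, 58593, 78125]
print(values)      # roughly [0.1, 0.01, 0.001, 0.0001]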
@@ -276,4 +273,5 @@ def main(unused_argv):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
@@ -97,6 +97,6 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__':
FLAGS = cifar10_main.parser.parse_args()
cifar10_main.FLAGS = FLAGS
cifar10_main.FLAGS = cifar10_main.parser.parse_args()
FLAGS = cifar10_main.FLAGS
tf.test.main()
@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
@@ -52,12 +53,7 @@ parser.add_argument(
'--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')
FLAGS = parser.parse_args()
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 256
_DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3
_LABEL_CLASSES = 1001
@@ -69,12 +65,6 @@ _NUM_IMAGES = {
'validation': 50000,
}
image_preprocessing_fn = vgg_preprocessing.preprocess_image
network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
def filenames(is_training):
"""Return filenames for dataset."""
@@ -118,10 +108,10 @@ def dataset_parser(value, is_training):
_NUM_CHANNELS)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = image_preprocessing_fn(
image = vgg_preprocessing.preprocess_image(
image=image,
output_height=network.default_image_size,
output_width=network.default_image_size,
output_height=_DEFAULT_IMAGE_SIZE,
output_width=_DEFAULT_IMAGE_SIZE,
is_training=is_training)
label = tf.cast(
@@ -158,6 +148,8 @@ def resnet_model_fn(features, labels, mode):
"""Our model_fn for ResNet to be used with our Estimator."""
tf.summary.image('images', features, max_outputs=6)
network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
logits = network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
@@ -183,13 +175,17 @@ def resnet_model_fn(features, labels, mode):
[tf.nn.l2_loss(v) for v in tf.trainable_variables()])
if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
initial_learning_rate = 0.1 * FLAGS.batch_size / 256
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
global_step = tf.train.get_or_create_global_step()
# Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
boundaries = [
int(batches_per_epoch * epoch) for epoch in [30, 60, 80, 90]]
values = [
_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]]
initial_learning_rate * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]]
learning_rate = tf.train.piecewise_constant(
tf.cast(global_step, tf.int32), boundaries, values)
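The ImageNet schedule above follows the same in-function pattern, with a 256-image reference batch and decay points at 30, 60, 80 and 90 epochs. One detail worth noting for tf.train.piecewise_constant: values must contain exactly one more entry than boundaries (five decay factors against four epoch boundaries here), since the first value applies before the first boundary and the last after the final one. A small sketch, assuming only what the diff shows:

# Shape check for the ImageNet piecewise schedule above: five values
# against four boundaries, as tf.train.piecewise_constant requires.
batch_size = 256
initial_learning_rate = 0.1 * batch_size / 256  # 0.1 at the reference batch size

epochs = [30, 60, 80, 90]
decays = [1, 0.1, 0.01, 1e-3, 1e-4]
assert len(decays) == len(epochs) + 1

values = [initial_learning_rate * decay for decay in decays]
print(values)  # roughly [0.1, 0.01, 0.001, 1e-4, 1e-5]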
@@ -257,4 +253,5 @@ def main(unused_argv):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
@@ -171,5 +171,6 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__':
imagenet_main.FLAGS = imagenet_main.parser.parse_args()
FLAGS = imagenet_main.FLAGS
tf.test.main()
@@ -275,7 +275,6 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
inputs = tf.identity(inputs, 'final_dense')
return inputs
model.default_image_size = 32
return model
@@ -344,7 +343,6 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
inputs = tf.identity(inputs, 'final_dense')
return inputs
model.default_image_size = 224
return model