Commit d86aa760 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Modify cifar10_main.py and imagenet_main.py to parse flags in main only (#2582)

parent 52a87b3a
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
import tensorflow as tf import tensorflow as tf
...@@ -46,29 +47,21 @@ parser.add_argument('--epochs_per_eval', type=int, default=10, ...@@ -46,29 +47,21 @@ parser.add_argument('--epochs_per_eval', type=int, default=10,
parser.add_argument('--batch_size', type=int, default=128, parser.add_argument('--batch_size', type=int, default=128,
help='The number of images per batch.') help='The number of images per batch.')
FLAGS = parser.parse_args()
_HEIGHT = 32 _HEIGHT = 32
_WIDTH = 32 _WIDTH = 32
_DEPTH = 3 _DEPTH = 3
_NUM_CLASSES = 10 _NUM_CLASSES = 10
_NUM_DATA_FILES = 5 _NUM_DATA_FILES = 5
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
# Scale the learning rate linearly with the batch size. When the batch size is
# 128, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 128
_MOMENTUM = 0.9
# We use a weight decay of 0.0002, which performs better than the 0.0001 that # We use a weight decay of 0.0002, which performs better than the 0.0001 that
# was originally suggested. # was originally suggested.
_WEIGHT_DECAY = 2e-4 _WEIGHT_DECAY = 2e-4
_MOMENTUM = 0.9
_BATCHES_PER_EPOCH = _NUM_IMAGES['train'] / FLAGS.batch_size _NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
def record_dataset(filenames): def record_dataset(filenames):
...@@ -205,11 +198,15 @@ def cifar10_model_fn(features, labels, mode): ...@@ -205,11 +198,15 @@ def cifar10_model_fn(features, labels, mode):
[tf.nn.l2_loss(v) for v in tf.trainable_variables()]) [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size
# is 128, the learning rate should be 0.1.
initial_learning_rate = 0.1 * FLAGS.batch_size / 128
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
global_step = tf.train.get_or_create_global_step() global_step = tf.train.get_or_create_global_step()
# Multiply the learning rate by 0.1 at 100, 150, and 200 epochs. # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]] boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]] values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
learning_rate = tf.train.piecewise_constant( learning_rate = tf.train.piecewise_constant(
tf.cast(global_step, tf.int32), boundaries, values) tf.cast(global_step, tf.int32), boundaries, values)
...@@ -276,4 +273,5 @@ def main(unused_argv): ...@@ -276,4 +273,5 @@ def main(unused_argv):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run() FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
...@@ -97,6 +97,6 @@ class BaseTest(tf.test.TestCase): ...@@ -97,6 +97,6 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
FLAGS = cifar10_main.parser.parse_args() cifar10_main.FLAGS = cifar10_main.parser.parse_args()
cifar10_main.FLAGS = FLAGS FLAGS = cifar10_main.FLAGS
tf.test.main() tf.test.main()
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
import tensorflow as tf import tensorflow as tf
...@@ -52,12 +53,7 @@ parser.add_argument( ...@@ -52,12 +53,7 @@ parser.add_argument(
'--batch_size', type=int, default=32, '--batch_size', type=int, default=32,
help='Batch size for training and evaluation.') help='Batch size for training and evaluation.')
FLAGS = parser.parse_args() _DEFAULT_IMAGE_SIZE = 224
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 256
_NUM_CHANNELS = 3 _NUM_CHANNELS = 3
_LABEL_CLASSES = 1001 _LABEL_CLASSES = 1001
...@@ -69,12 +65,6 @@ _NUM_IMAGES = { ...@@ -69,12 +65,6 @@ _NUM_IMAGES = {
'validation': 50000, 'validation': 50000,
} }
image_preprocessing_fn = vgg_preprocessing.preprocess_image
network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
def filenames(is_training): def filenames(is_training):
"""Return filenames for dataset.""" """Return filenames for dataset."""
...@@ -118,10 +108,10 @@ def dataset_parser(value, is_training): ...@@ -118,10 +108,10 @@ def dataset_parser(value, is_training):
_NUM_CHANNELS) _NUM_CHANNELS)
image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = image_preprocessing_fn( image = vgg_preprocessing.preprocess_image(
image=image, image=image,
output_height=network.default_image_size, output_height=_DEFAULT_IMAGE_SIZE,
output_width=network.default_image_size, output_width=_DEFAULT_IMAGE_SIZE,
is_training=is_training) is_training=is_training)
label = tf.cast( label = tf.cast(
...@@ -158,6 +148,8 @@ def resnet_model_fn(features, labels, mode): ...@@ -158,6 +148,8 @@ def resnet_model_fn(features, labels, mode):
"""Our model_fn for ResNet to be used with our Estimator.""" """Our model_fn for ResNet to be used with our Estimator."""
tf.summary.image('images', features, max_outputs=6) tf.summary.image('images', features, max_outputs=6)
network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
logits = network( logits = network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN)) inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
...@@ -183,13 +175,17 @@ def resnet_model_fn(features, labels, mode): ...@@ -183,13 +175,17 @@ def resnet_model_fn(features, labels, mode):
[tf.nn.l2_loss(v) for v in tf.trainable_variables()]) [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
initial_learning_rate = 0.1 * FLAGS.batch_size / 256
batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
global_step = tf.train.get_or_create_global_step() global_step = tf.train.get_or_create_global_step()
# Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs. # Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
boundaries = [ boundaries = [
int(batches_per_epoch * epoch) for epoch in [30, 60, 80, 90]] int(batches_per_epoch * epoch) for epoch in [30, 60, 80, 90]]
values = [ values = [
_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]] initial_learning_rate * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]]
learning_rate = tf.train.piecewise_constant( learning_rate = tf.train.piecewise_constant(
tf.cast(global_step, tf.int32), boundaries, values) tf.cast(global_step, tf.int32), boundaries, values)
...@@ -257,4 +253,5 @@ def main(unused_argv): ...@@ -257,4 +253,5 @@ def main(unused_argv):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run() FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
...@@ -171,5 +171,6 @@ class BaseTest(tf.test.TestCase): ...@@ -171,5 +171,6 @@ class BaseTest(tf.test.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
imagenet_main.FLAGS = imagenet_main.parser.parse_args()
FLAGS = imagenet_main.FLAGS FLAGS = imagenet_main.FLAGS
tf.test.main() tf.test.main()
...@@ -275,7 +275,6 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None): ...@@ -275,7 +275,6 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
inputs = tf.identity(inputs, 'final_dense') inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
model.default_image_size = 32
return model return model
...@@ -344,7 +343,6 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes, ...@@ -344,7 +343,6 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
inputs = tf.identity(inputs, 'final_dense') inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
model.default_image_size = 224
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment