Commit fb12693f authored by Neal Wu, committed by GitHub

Updates to resnet, including using epochs in flags rather than steps (#2552)

parent 6263692b
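
The core change in this commit is swapping the step-based flags (--train_steps, --steps_per_eval) for epoch-based ones (--train_epochs, --epochs_per_eval). A quick sketch of the arithmetic for the CIFAR-10 defaults, using only values that appear in this diff (the variable names below are illustrative, not from the code):

# 50,000 training images (_NUM_IMAGES['train']) and the default batch size of 128.
batches_per_epoch = 50000 / 128                 # 390.625
old_default_steps = 100000                      # previous --train_steps default
new_default_epochs = 250                        # new --train_epochs default
print(new_default_epochs * batches_per_epoch)   # ~97,656 steps, close to the old default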
@@ -25,14 +25,6 @@ import tensorflow as tf
import resnet_model
HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
parser = argparse.ArgumentParser()
# Basic model parameters.
@@ -45,10 +37,10 @@ parser.add_argument('--model_dir', type=str, default='/tmp/cifar10_model',
parser.add_argument('--resnet_size', type=int, default=32,
help='The size of the ResNet model to use.')
parser.add_argument('--train_steps', type=int, default=100000,
help='The number of batches to train.')
parser.add_argument('--train_epochs', type=int, default=250,
help='The number of epochs to train.')
parser.add_argument('--steps_per_eval', type=int, default=4000,
parser.add_argument('--epochs_per_eval', type=int, default=10,
help='The number of batches to run in between evaluations.')
parser.add_argument('--batch_size', type=int, default=128,
@@ -56,6 +48,17 @@ parser.add_argument('--batch_size', type=int, default=128,
FLAGS = parser.parse_args()
_HEIGHT = 32
_WIDTH = 32
_DEPTH = 3
_NUM_CLASSES = 10
_NUM_DATA_FILES = 5
_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}
# Scale the learning rate linearly with the batch size. When the batch size is
# 128, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 128
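# Quick check of the scaling rule above (illustration only, not part of this
# commit): the initial learning rate for a few batch sizes.
for batch_size in (64, 128, 256):
    print(batch_size, 0.1 * batch_size / 128)   # -> 0.05, 0.1, 0.2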
@@ -65,32 +68,30 @@ _MOMENTUM = 0.9
# was originally suggested.
_WEIGHT_DECAY = 2e-4
_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
_BATCHES_PER_EPOCH = _NUM_IMAGES['train'] / FLAGS.batch_size
def record_dataset(filenames):
"""Returns an input pipeline Dataset from `filenames`."""
record_bytes = HEIGHT * WIDTH * DEPTH + 1
record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
def get_filenames(mode):
"""Returns a list of filenames based on 'mode'."""
def get_filenames(is_training):
"""Returns a list of filenames."""
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
assert os.path.exists(data_dir), (
'Run cifar10_download_and_extract.py first to download and extract the '
'CIFAR-10 data.')
if mode == tf.estimator.ModeKeys.TRAIN:
if is_training:
return [
os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, NUM_DATA_BATCHES + 1)
for i in range(1, _NUM_DATA_FILES + 1)
]
elif mode == tf.estimator.ModeKeys.EVAL:
return [os.path.join(data_dir, 'test_batch.bin')]
else:
raise ValueError('Invalid mode: %s' % mode)
return [os.path.join(data_dir, 'test_batch.bin')]
def dataset_parser(value):
@@ -98,7 +99,7 @@ def dataset_parser(value):
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
label_bytes = 1
image_bytes = HEIGHT * WIDTH * DEPTH
image_bytes = _HEIGHT * _WIDTH * _DEPTH
record_bytes = label_bytes + image_bytes
# Convert from a string to a vector of uint8 that is record_bytes long.
@@ -110,22 +111,22 @@ def dataset_parser(value):
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
[DEPTH, HEIGHT, WIDTH])
[_DEPTH, _HEIGHT, _WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
return image, tf.one_hot(label, NUM_CLASSES)
return image, tf.one_hot(label, _NUM_CLASSES)
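# Illustration only (not part of this diff): the same record layout decoded
# with plain numpy, mirroring the reshape and transpose above.
import numpy as np
raw_bytes = bytes(1 + 3 * 32 * 32)                    # placeholder for one 3073-byte record
record = np.frombuffer(raw_bytes, dtype=np.uint8)
label = int(record[0])                                # first byte is the label
image = record[1:].reshape(3, 32, 32)                 # [depth, height, width]
image = image.transpose(1, 2, 0).astype(np.float32)   # -> [height, width, depth]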
def train_preprocess_fn(image, label):
"""Preprocess a single training image of layout [height, width, depth]."""
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
# Randomly crop a [HEIGHT, WIDTH] section of the image.
image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)
@@ -133,44 +134,41 @@ def train_preprocess_fn(image, label):
return image, label
def input_fn(mode, batch_size):
def input_fn(is_training, num_epochs=1):
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
Args:
mode: Standard names for model modes from tf.estimator.ModeKeys.
batch_size: The number of samples per batch of input requested.
is_training: A boolean denoting whether the input is for training.
num_epochs: The number of epochs to repeat the dataset.
Returns:
A tuple of images and labels.
"""
dataset = record_dataset(get_filenames(mode))
# For training repeat forever.
if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.repeat()
dataset = record_dataset(get_filenames(is_training))
dataset = dataset.map(dataset_parser, num_threads=1,
output_buffer_size=2 * batch_size)
output_buffer_size=2 * FLAGS.batch_size)
# For training, preprocess the image and shuffle.
if mode == tf.estimator.ModeKeys.TRAIN:
if is_training:
dataset = dataset.map(train_preprocess_fn, num_threads=1,
output_buffer_size=2 * batch_size)
output_buffer_size=2 * FLAGS.batch_size)
# Ensure that the capacity is sufficiently large to provide good random
# shuffling.
buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
buffer_size = int(0.4 * _NUM_IMAGES['train'])
dataset = dataset.shuffle(buffer_size=buffer_size)
# Subtract off the mean and divide by the variance of the pixels.
dataset = dataset.map(
lambda image, label: (tf.image.per_image_standardization(image), label),
num_threads=1,
output_buffer_size=2 * batch_size)
output_buffer_size=2 * FLAGS.batch_size)
dataset = dataset.repeat(num_epochs)
# Batch results by up to batch_size, and then fetch the tuple from the
# iterator.
iterator = dataset.batch(batch_size).make_one_shot_iterator()
iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
images, labels = iterator.get_next()
return images, labels
@@ -181,9 +179,9 @@ def cifar10_model_fn(features, labels, mode):
tf.summary.image('images', features, max_outputs=6)
network = resnet_model.cifar10_resnet_v2_generator(
FLAGS.resnet_size, NUM_CLASSES)
FLAGS.resnet_size, _NUM_CLASSES)
inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
predictions = {
@@ -250,10 +248,12 @@ def main(unused_argv):
# Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
cifar_classifier = tf.estimator.Estimator(
model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir)
model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config)
for _ in range(FLAGS.train_steps // FLAGS.steps_per_eval):
for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
tensors_to_log = {
'learning_rate': 'learning_rate',
'cross_entropy': 'cross_entropy',
@@ -264,15 +264,13 @@
tensors=tensors_to_log, every_n_iter=100)
cifar_classifier.train(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN,
batch_size=FLAGS.batch_size),
steps=FLAGS.steps_per_eval,
input_fn=lambda: input_fn(
is_training=True, num_epochs=FLAGS.epochs_per_eval),
hooks=[logging_hook])
# Evaluate the model and print results
eval_results = cifar_classifier.evaluate(
input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL,
batch_size=FLAGS.batch_size))
input_fn=lambda: input_fn(is_training=False))
print(eval_results)
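# Illustration only: with the defaults in this file, the main loop above runs
# 250 // 10 = 25 train/eval cycles. Each call to train() now runs until the
# epoch-bounded input_fn is exhausted (via dataset.repeat(num_epochs)) instead
# of taking an explicit steps= argument.
print(250 // 10)          # 25 cycles
print(10 * 50000 / 128)   # ~3,906 training batches per cycle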
......
@@ -41,26 +41,17 @@ parser.add_argument(
help='The size of the ResNet model to use.')
parser.add_argument(
'--train_steps', type=int, default=6400000,
help='The number of steps to use for training.')
'--train_epochs', type=int, default=100,
help='The number of epochs to use for training.')
parser.add_argument(
'--steps_per_eval', type=int, default=40000,
help='The number of training steps to run between evaluations.')
'--epochs_per_eval', type=int, default=1,
help='The number of training epochs to run between evaluations.')
parser.add_argument(
'--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')
parser.add_argument(
'--map_threads', type=int, default=5,
help='The number of threads for dataset.map.')
parser.add_argument(
'--first_cycle_steps', type=int, default=None,
help='The number of steps to run before the first evaluation. Useful if '
'you have stopped partway through a training cycle.')
FLAGS = parser.parse_args()
# Scale the learning rate linearly with the batch size. When the batch size is
@@ -140,18 +131,18 @@ def dataset_parser(value, is_training):
return image, tf.one_hot(label, _LABEL_CLASSES)
def input_fn(is_training):
"""Input function which provides a single batch for train or eval."""
def input_fn(is_training, num_epochs=1):
"""Input function which provides batches for train or eval."""
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames(is_training))
if is_training:
dataset = dataset.shuffle(buffer_size=1024)
dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
if is_training:
dataset = dataset.repeat()
dataset = dataset.repeat(num_epochs)
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
num_threads=FLAGS.map_threads,
num_threads=5,
output_buffer_size=FLAGS.batch_size)
if is_training:
@@ -194,9 +185,9 @@ def resnet_model_fn(features, labels, mode):
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
# Multiply the learning rate by 0.1 at 30, 60, 120, and 150 epochs.
# Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
boundaries = [
int(batches_per_epoch * epoch) for epoch in [30, 60, 120, 150]]
int(batches_per_epoch * epoch) for epoch in [30, 60, 80, 90]]
values = [
_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]]
learning_rate = tf.train.piecewise_constant(
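# Sketch (illustration only, not part of this commit) of the piecewise-constant
# schedule defined above, written as a plain Python function over epochs;
# initial_lr stands in for _INITIAL_LEARNING_RATE.
def learning_rate_at(epoch, initial_lr):
    boundaries = [30, 60, 80, 90]
    decays = [1, 0.1, 0.01, 1e-3, 1e-4]
    for boundary, decay in zip(boundaries, decays):
        if epoch < boundary:
            return initial_lr * decay
    return initial_lr * decays[-1]

print(learning_rate_at(45, 0.1))   # 0.01 (after the first decay at epoch 30)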
@@ -237,10 +228,12 @@ def main(unused_argv):
# Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
resnet_classifier = tf.estimator.Estimator(
model_fn=resnet_model_fn, model_dir=FLAGS.model_dir)
model_fn=resnet_model_fn, model_dir=FLAGS.model_dir, config=run_config)
for _ in range(FLAGS.train_steps // FLAGS.steps_per_eval):
for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
tensors_to_log = {
'learning_rate': 'learning_rate',
'cross_entropy': 'cross_entropy',
@@ -252,13 +245,13 @@
print('Starting a training cycle.')
resnet_classifier.train(
input_fn=lambda: input_fn(True),
steps=FLAGS.first_cycle_steps or FLAGS.steps_per_eval,
input_fn=lambda: input_fn(
is_training=True, num_epochs=FLAGS.epochs_per_eval),
hooks=[logging_hook])
FLAGS.first_cycle_steps = None
print('Starting to evaluate.')
eval_results = resnet_classifier.evaluate(input_fn=lambda: input_fn(False))
eval_results = resnet_classifier.evaluate(
input_fn=lambda: input_fn(is_training=False))
print(eval_results)
......