Commit 26a2c3b9 authored by Shining Sun

Commit before pull

parent cd034c8c
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""
"""Runs a ResNet model on the Cifar-10 dataset."""
from __future__ import absolute_import
from __future__ import division
@@ -32,27 +32,18 @@ from official.resnet.keras import keras_resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
# LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
# (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
# ]
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(0.1, 91), (0.01, 136), (0.001, 182)
]
BASE_LEARNING_RATE = 0.1
def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
The learning rate starts at 0, then it increases linearly per step.
After 5 epochs we reach the base learning rate (scaled to account
for batch size).
After 30, 60 and 80 epochs the learning rate is divided by 10.
After 90 epochs training stops and the LR is set to 0. This ensures
that we train for exactly 90 epochs for reproducibility.
The learning rate starts at base learning_rate, then after 91, 136 and
182 epochs, the learning rate is divided by 10.
Args:
current_epoch: integer, current epoch indexed from 0.
@@ -61,19 +52,7 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
Returns:
Adjusted learning rate.
"""
# epoch = current_epoch + float(current_batch) / batches_per_epoch
# warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
# if epoch < warmup_end_epoch:
# # Learning rate increases linearly per step.
# return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
# for mult, start_epoch in LR_SCHEDULE:
# if epoch >= start_epoch:
# learning_rate = BASE_LEARNING_RATE * mult
# else:
# break
# return learning_rate
initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
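For reference, here is a self-contained sketch of the step-decay rule described in the docstring above; the constants are copied from this file, the helper name is illustrative, and the loop body (collapsed in this hunk) is filled in following the commented-out version that was removed:

# Illustrative sketch, not part of the diff.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]  # (multiplier, epoch to start)

def cifar_step_decay(current_epoch, batch_size):  # hypothetical helper name
  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate

# With batch_size=128: epochs 0-90 -> 0.1, 91-135 -> 0.01, 136-181 -> 0.001, 182+ -> 0.0001.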
@@ -103,8 +82,8 @@ def parse_record_keras(raw_record, is_training, dtype):
return image, label
def run_cifar_with_keras(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs.
def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
@@ -164,17 +143,17 @@ def run_cifar_with_keras(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
opt, loss, accuracy = keras_common.get_optimizer_loss_and_metrics()
optimizer = keras_common.get_optimizer()
strategy = keras_common.get_dist_strategy()
model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3),
include_top=True,
classes=cifar_main._NUM_CLASSES,
weights=None)
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy],
classes=cifar_main._NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'],
distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
learning_rate_schedule)
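As context for the compile call above, here is a minimal, hypothetical sketch of the TF 1.x-era pattern used in this change, where the distribution strategy is passed to Keras compile() through the distribute argument; the toy model and the num_gpus value are placeholders:

# Illustrative sketch only; assumes the TF 1.x-era `distribute=` compile argument.
import tensorflow as tf
from official.utils.misc import distribution_utils

strategy = distribution_utils.get_distribution_strategy(num_gpus=2)
toy_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
toy_model.compile(loss='categorical_crossentropy',
                  optimizer=tf.keras.optimizers.SGD(lr=0.1, momentum=0.9),
                  metrics=['categorical_accuracy'],
                  distribute=strategy)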
@@ -204,17 +183,12 @@ def run_cifar_with_keras(flags_obj):
return stats
def define_keras_cifar_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_cifar_with_keras(flags.FLAGS)
run(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG)
define_keras_cifar_flags()
cifar_main.define_cifar_flags()
absl_app.run(main)
@@ -26,18 +26,12 @@ import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_run_loop
from official.resnet.keras import keras_resnet_model
from official.resnet.keras import resnet_model_tpu
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
@@ -116,28 +110,23 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
'learning rate to %s.', self.epochs, batch, lr)
def get_optimizer_loss_and_metrics():
# Use Keras ResNet50 applications model and native keras APIs
# initialize RMSprop optimizer
# TODO(anjalisridhar): Move to using MomentumOptimizer.
# opt = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
# I am setting an initial LR of 0.001 since this will be reset
# at the beginning of the training loop.
opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
# TF Optimizer:
# learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
# opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy'
def get_optimizer():
if FLAGS.use_tf_momentum_optimizer:
learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
else:
optimizer = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
return opt, loss, accuracy
return optimizer
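Illustrative arithmetic for the linear scaling rule on the MomentumOptimizer path above (the batch size is assumed):

# With BASE_LEARNING_RATE = 0.1 and an assumed global batch size of 1024:
BASE_LEARNING_RATE = 0.1
batch_size = 1024
print(BASE_LEARNING_RATE * batch_size / 256)  # -> 0.4
# On the default path, SGD starts at 0.1 and is expected to be adjusted per
# batch by the LearningRateBatchScheduler callback instead.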
def get_dist_strategy():
if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
if FLAGS.num_gpus == 1 and not FLAGS.use_one_device_strategy:
print('Not using distribution strategies.')
strategy = None
elif FLAGS.num_gpus > 1 and FLAGS.use_one_device_strategy:
raise ValueError('When %d GPUs are specified, the use_one_device_strategy '
'flag cannot be set to True.' % FLAGS.num_gpus)
else:
strategy = distribution_utils.get_distribution_strategy(
num_gpus=FLAGS.num_gpus)
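For readers unfamiliar with the LearningRateBatchScheduler referenced above, here is a minimal illustrative sketch of such a per-batch callback; it is not the implementation in keras_common, and the optimizer attribute holding the learning rate may be named lr or learning_rate depending on the optimizer class:

import tensorflow as tf

class SimpleLRBatchScheduler(tf.keras.callbacks.Callback):  # illustrative name
  """Calls schedule(epoch, batch, batches_per_epoch, batch_size) before each batch."""

  def __init__(self, schedule, batches_per_epoch, batch_size):
    super(SimpleLRBatchScheduler, self).__init__()
    self.schedule = schedule
    self.batches_per_epoch = batches_per_epoch
    self.batch_size = batch_size
    self.epochs = -1

  def on_epoch_begin(self, epoch, logs=None):
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
    # `lr` is the attribute on tf.keras v1 optimizers; optimizer_v2 uses `learning_rate`.
    tf.keras.backend.set_value(self.model.optimizer.lr, lr)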
@@ -29,18 +29,16 @@ from official.resnet import imagenet_main
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common
from official.resnet.keras import keras_resnet_model
from official.resnet.keras import resnet_model_tpu
from official.resnet.keras import resnet50
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
@@ -59,7 +57,7 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
Returns:
Adjusted learning rate.
"""
initial_learning_rate = BASE_LEARNING_RATE * batch_size / 256
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
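To make the warmup arithmetic concrete (the values are illustrative; with batch_size = 256 the scaled base LR equals BASE_LEARNING_RATE, and the post-warmup branch is collapsed in this hunk):

# Illustrative warmup arithmetic.
BASE_LEARNING_RATE = 0.1
warmup_multiplier, warmup_end_epoch = (1.0, 5)  # LR_SCHEDULE[0]
epoch = 2.5                                     # halfway through warmup
print(BASE_LEARNING_RATE * warmup_multiplier * epoch / warmup_end_epoch)  # -> 0.05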
@@ -74,32 +72,12 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
def parse_record_keras(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: Data type to use for input images.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer,
bbox=bbox,
output_height=imagenet_main._DEFAULT_IMAGE_SIZE,
output_width=imagenet_main._DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main._NUM_CHANNELS,
is_training=is_training)
image = tf.cast(image, dtype)
label = tf.sparse_to_dense(label, (imagenet_main._NUM_CLASSES,), 1)
"""Adjust the shape of label."""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32)
return image, label
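The label handling above can be exercised in isolation; the following illustrative snippet mirrors the lines just shown (a raw one-based ImageNet label becomes a zero-based float32 vector of shape [1]):

# Illustrative sketch of the label adjustment above.
import tensorflow as tf

raw_label = tf.constant(5, dtype=tf.int32)  # one-based class id from the record
label = tf.cast(
    tf.cast(tf.reshape(raw_label, shape=[1]), dtype=tf.int32) - 1,
    dtype=tf.float32)  # -> [4.0], i.e. zero-based and float32 for the Keras model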
@@ -161,14 +139,14 @@ def run_imagenet_with_keras(flags_obj):
parse_record_fn=parse_record_keras)
opt, loss, accuracy = keras_common.get_optimizer_loss_and_metrics()
optimizer = keras_common.get_optimizer()
strategy = keras_common.get_dist_strategy()
model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
model = resnet50.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy],
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'],
distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
@@ -199,9 +177,6 @@ def run_imagenet_with_keras(flags_obj):
return stats
def define_keras_imagenet_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
def main(_):
with logger.benchmark_context(flags.FLAGS):
@@ -210,6 +185,5 @@ def main(_):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_keras_imagenet_flags()
imagenet_main.define_imagenet_flags()
absl_app.run(main)
@@ -629,9 +629,13 @@ def define_resnet_flags(resnet_size_choices=None):
'inference. Note, this flag only applies to ImageNet and cannot '
'be used for CIFAR.'))
flags.DEFINE_boolean(
name='dist_strat_off', default=False,
help=flags_core.help_wrap('Set to true to not use distribution '
name='use_one_device_strategy', default=True,
help=flags_core.help_wrap('Set to False to not use distribution '
'strategies.'))
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='use_tf_momentum_optimizer', default=False,
help='Use tf MomentumOptimizer.')
choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
help=flags_core.help_wrap('The size of the ResNet model to use.'))
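A minimal, hypothetical sketch of how boolean flags defined this way are read once absl has parsed the command line (the flag definition is copied from above; the surrounding script is illustrative):

# Illustrative only: defining and reading an absl boolean flag.
from absl import app as absl_app
from absl import flags

flags.DEFINE_boolean(name='use_tf_momentum_optimizer', default=False,
                     help='Use tf MomentumOptimizer.')
FLAGS = flags.FLAGS

def main(_):
  # After parsing, flags are plain attributes on FLAGS.
  print('use_tf_momentum_optimizer =', FLAGS.use_tf_momentum_optimizer)

if __name__ == '__main__':
  absl_app.run(main)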