Commit 61ec2907 authored by Toby Boyd

Add return stats after run for cifar

- keras_cifar_main returns dict of stats after run.
- unit tests added for keras_common.
- lint issues fixed in all 4 files.
parent 8d5c1684
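With this change, run() in keras_cifar_main finishes by calling keras_common.build_stats() and returning the result, so callers (such as the benchmark's _fill_report_object below) receive plain Python floats instead of NumPy scalars. A minimal sketch of consuming the returned dictionary; the flags_obj setup is assumed and not part of this commit:

    from official.resnet.keras import keras_cifar_main

    # flags_obj is assumed to come from the usual absl flag parsing in main().
    stats = keras_cifar_main.run(flags_obj)
    print(stats['loss'])                     # final training loss
    print(stats['training_accuracy_top_1'])  # final top-1 training accuracy
    # Only present when eval ran (i.e. --skip_eval was not set):
    print(stats.get('accuracy_top_1'))       # eval top-1 accuracy
    print(stats.get('eval_loss'))            # eval loss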
@@ -66,9 +66,8 @@ class KerasCifar10BenchmarkTests(object):
   def _fill_report_object(self, stats):
     if self.oss_report_object:
-      self.oss_report_object.top_1 = stats['accuracy_top_1'].item()
-      self.oss_report_object.add_other_quality(stats['training_accuracy_top_1']
-                                               .item(),
-                                               'top_1_train_accuracy')
+      self.oss_report_object.top_1 = stats['accuracy_top_1']
+      self.oss_report_object.add_other_quality(stats['training_accuracy_top_1'],
+                                               'top_1_train_accuracy')
     else:
       raise ValueError('oss_report_object has not been set.')

@@ -23,7 +23,6 @@ from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import cifar10_main as cifar_main
-from official.resnet import resnet_run_loop
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet_cifar_model
 from official.utils.flags import core as flags_core
@@ -36,15 +35,20 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 ]


-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
+def learning_rate_schedule(current_epoch,
+                           current_batch,
+                           batches_per_epoch,
+                           batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.

-  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the provided scaling
-  factor.
+  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
+  provided scaling factor.

   Args:
     current_epoch: integer, current epoch indexed from 0.
     current_batch: integer, current batch in the current epoch, indexed from 0.
+    batches_per_epoch: integer, number of steps in an epoch.
+    batch_size: integer, total batch size.

   Returns:
     Adjusted learning rate.
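The hunk above reflows only the signature and docstring; the body (unchanged in this commit) applies the rule the docstring describes. A sketch under the assumption that LR_SCHEDULE holds (multiplier, epoch to start) tuples; the concrete values below are illustrative, since the actual tuples sit outside this hunk:

    BASE_LEARNING_RATE = 0.1
    # Illustrative (multiplier, epoch to start) tuples; the real values live
    # in LR_SCHEDULE above this hunk.
    LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]

    def learning_rate_schedule(current_epoch, current_batch,
                               batches_per_epoch, batch_size):
      # Linear scaling rule: grow the base LR proportionally with batch size.
      initial_lr = BASE_LEARNING_RATE * batch_size / 128
      learning_rate = initial_lr
      # Step the LR down at each epoch boundary listed in LR_SCHEDULE.
      for mult, start_epoch in LR_SCHEDULE:
        if current_epoch >= start_epoch:
          learning_rate = initial_lr * mult
        else:
          break
      # current_batch / batches_per_epoch would drive intra-epoch warmup;
      # they are unused in this simplified sketch.
      return learning_rate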
@@ -89,6 +93,9 @@ def run(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
+
+  Returns:
+    Dictionary of training and eval stats.
   """
   if flags_obj.enable_eager:
     tf.enable_eager_execution()
@@ -127,10 +134,10 @@ def run(flags_obj):
   optimizer = keras_common.get_optimizer()

   strategy = distribution_utils.get_distribution_strategy(
       flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

-  model = resnet_cifar_model.resnet56(input_shape=(32, 32, 3),
-                                        classes=cifar_main.NUM_CLASSES)
+  model = resnet_cifar_model.resnet56(input_shape=(32, 32, 3),
+                                      classes=cifar_main.NUM_CLASSES)
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
@@ -156,21 +163,23 @@ def run(flags_obj):
     validation_data = None

   history = model.fit(train_input_dataset,
                       epochs=train_epochs,
                       steps_per_epoch=train_steps,
                       callbacks=[
                           time_callback,
                           lr_callback,
                           tensorboard_callback
                       ],
                       validation_steps=num_eval_steps,
                       validation_data=validation_data,
                       verbose=1)

+  eval_output = None
   if not flags_obj.skip_eval:
     eval_output = model.evaluate(eval_input_dataset,
                                  steps=num_eval_steps,
                                  verbose=1)
+  stats = keras_common.build_stats(history, eval_output)
+  return stats


 def main(_):

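Because eval_output is initialized to None and build_stats (below) only fills the eval keys when it is truthy, a --skip_eval run returns just the training stats. For illustration, assuming a run with --skip_eval set:

    stats = run(flags_obj)  # hypothetical invocation with --skip_eval
    sorted(stats)           # ['loss', 'training_accuracy_top_1'] -- no eval keys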
@@ -29,6 +29,8 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_de
 FLAGS = flags.FLAGS
 BASE_LEARNING_RATE = 0.1  # This matches Jing's version.
+TRAIN_TOP_1 = 'training_accuracy_top_1'
+

 class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""
@@ -49,7 +51,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
   def on_batch_begin(self, batch, logs=None):
     if self.record_batch:
-      self.start_time= time.time()
+      self.start_time = time.time()
       self.record_batch = False

   def on_batch_end(self, batch, logs=None):
@@ -63,6 +65,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
                      "'images_per_second': %f}" %
                      (batch, elapsed_time, examples_per_second))
+

 class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
   """Callback to update learning rate on every batch (not epoch boundaries).
@@ -88,43 +91,77 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
       self.epochs += 1

   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
+    lr = self.schedule(self.epochs,
+                       batch,
+                       self.batches_per_epoch,
+                       self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
       self.model.optimizer.learning_rate = lr  # lr should be a float here
       self.prev_lr = lr
-      tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
-                       'learning rate to %s.', self.epochs, batch, lr)
+      tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
+                       'change learning rate to %s.', self.epochs, batch, lr)


 def get_optimizer():
-  # The learning rate set here is a placeholder and not use. It will be overwritten
-  # at the beginning of each batch by callback
+  """Returns optimizer to use."""
+  # The learning_rate is overwritten at the beginning of each step by callback.
   return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)

 def get_callbacks(learning_rate_schedule_fn, num_images):
+  """Returns common callbacks."""
   time_callback = TimeHistory(FLAGS.batch_size)

   tensorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=FLAGS.model_dir)

   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule_fn,
       batch_size=FLAGS.batch_size,
       num_images=num_images)

   return time_callback, tensorboard_callback, lr_callback
+
+def build_stats(history, eval_output):
+  """Normalizes and returns dictionary of stats.
+
+  Args:
+    history: Results of the training step. Supports both categorical_accuracy
+      and sparse_categorical_accuracy.
+    eval_output: Output of the eval step. Assumes first value is eval_loss and
+      second value is accuracy_top_1.
+
+  Returns:
+    Dictionary of normalized results.
+  """
+  stats = {}
+  if eval_output:
+    stats['accuracy_top_1'] = eval_output[1].item()
+    stats['eval_loss'] = eval_output[0].item()
+  if history and history.history:
+    train_hist = history.history
+    # Gets final loss from training.
+    stats['loss'] = train_hist['loss'][-1].item()
+    # Gets top_1 training accuracy.
+    if 'categorical_accuracy' in train_hist:
+      stats[TRAIN_TOP_1] = train_hist['categorical_accuracy'][-1].item()
+    elif 'sparse_categorical_accuracy' in train_hist:
+      stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
+  return stats


 def define_keras_flags():
   flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
   flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
   flags.DEFINE_integer(
-      name="train_steps", default=None,
-      help="The number of steps to run for training. If it is larger than "
-      "# batches per epoch, then use # batches per epoch. When this flag is "
-      "set, only one epoch is going to run for training.")
+      name='train_steps', default=None,
+      help='The number of steps to run for training. If it is larger than '
+      '# batches per epoch, then use # batches per epoch. When this flag is '
+      'set, only one epoch is going to run for training.')


 def get_synth_input_fn(height, width, num_channels, num_classes,

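The repeated .item() calls in build_stats are what let _fill_report_object in the first hunk drop its own .item() conversions: the Keras history values and model.evaluate() outputs are NumPy scalars here, and .item() turns them into native Python floats (which, unlike np.float64, serialize cleanly with the standard json module). A quick check of that behavior:

    import numpy as np

    x = np.float64(0.99988)
    print(type(x))         # <class 'numpy.float64'>
    print(type(x.item()))  # <class 'float'>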
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""

from __future__ import absolute_import
from __future__ import print_function

from mock import Mock
import numpy as np
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet.keras import keras_common

tf.logging.set_verbosity(tf.logging.ERROR)


class KerasCommonTests(tf.test.TestCase):
  """Tests for keras_common."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(KerasCommonTests, cls).setUpClass()

  def test_build_stats(self):
    history = self._build_history(1.145, cat_accuracy=.99988)
    eval_output = self._build_eval_output(.56432111, 5.990)
    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.56432111, stats['accuracy_top_1'])
    self.assertEqual(5.990, stats['eval_loss'])

  def test_build_stats_sparse(self):
    history = self._build_history(1.145, cat_accuracy_sparse=.99988)
    eval_output = self._build_eval_output(.928, 1.9844)
    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.928, stats['accuracy_top_1'])
    self.assertEqual(1.9844, stats['eval_loss'])

  def _build_history(self, loss, cat_accuracy=None,
                     cat_accuracy_sparse=None):
    history_p = Mock()
    history = {}
    history_p.history = history
    history['loss'] = [np.float64(loss)]
    if cat_accuracy:
      history['categorical_accuracy'] = [np.float64(cat_accuracy)]
    if cat_accuracy_sparse:
      history['sparse_categorical_accuracy'] = [
          np.float64(cat_accuracy_sparse)]
    return history_p

  def _build_eval_output(self, top_1, eval_loss):
    eval_output = [np.float64(eval_loss), np.float64(top_1)]
    return eval_output