Commit 61ec2907 authored by Toby Boyd's avatar Toby Boyd
Browse files

Add return stats after run for cifar

- keras_cifar_main returns dict of stats after run.
- unit tests added for keras_common.
- lint issues fixed in all 4 files.
parent 8d5c1684
......@@ -66,9 +66,8 @@ class KerasCifar10BenchmarkTests(object):
def _fill_report_object(self, stats):
if self.oss_report_object:
self.oss_report_object.top_1 = stats['accuracy_top_1'].item()
self.oss_report_object.add_other_quality(stats['training_accuracy_top_1']
.item(),
self.oss_report_object.top_1 = stats['accuracy_top_1']
self.oss_report_object.add_other_quality(stats['training_accuracy_top_1'],
'top_1_train_accuracy')
else:
raise ValueError('oss_report_object has not been set.')
......
......@@ -23,7 +23,6 @@ from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
......@@ -36,15 +35,20 @@ LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
]
def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the provided scaling
factor.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
  batch_size: integer, total batch size.
Returns:
Adjusted learning rate.
......@@ -89,6 +93,9 @@ def run(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
Returns:
Dictionary of training and eval stats.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
......@@ -166,11 +173,13 @@ def run(flags_obj):
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=1)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output)
return stats
def main(_):
......
......@@ -29,6 +29,8 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_de
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1'
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
......@@ -49,7 +51,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None):
if self.record_batch:
self.start_time= time.time()
self.start_time = time.time()
self.record_batch = False
def on_batch_end(self, batch, logs=None):
......@@ -63,6 +65,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
"'images_per_second': %f}" %
(batch, elapsed_time, examples_per_second))
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
......@@ -88,22 +91,27 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
lr = self.schedule(self.epochs,
batch,
self.batches_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
'learning rate to %s.', self.epochs, batch, lr)
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
def get_optimizer():
# The learning rate set here is a placeholder and not use. It will be overwritten
# at the beginning of each batch by callback
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks."""
time_callback = TimeHistory(FLAGS.batch_size)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
......@@ -117,14 +125,43 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
return time_callback, tensorboard_callback, lr_callback
def build_stats(history, eval_output):
  """Normalizes and returns dictionary of stats.

  Args:
    history: Results of the training step. Supports both categorical_accuracy
      and sparse_categorical_accuracy.
    eval_output: Output of the eval step. Assumes first value is eval_loss and
      second value is accuracy_top_1.

  Returns:
    Dictionary of normalized results.
  """
  stats = {}
  if eval_output:
    # float() normalizes both numpy scalars and plain Python numbers;
    # numpy's .item() would raise AttributeError on plain floats.
    stats['accuracy_top_1'] = float(eval_output[1])
    stats['eval_loss'] = float(eval_output[0])
  if history and history.history:
    train_hist = history.history
    # Gets final loss from training.
    stats['loss'] = float(train_hist['loss'][-1])
    # Gets top_1 training accuracy; the metric key depends on whether labels
    # were one-hot encoded (categorical) or integer encoded (sparse).
    if 'categorical_accuracy' in train_hist:
      stats[TRAIN_TOP_1] = float(train_hist['categorical_accuracy'][-1])
    elif 'sparse_categorical_accuracy' in train_hist:
      stats[TRAIN_TOP_1] = float(train_hist['sparse_categorical_accuracy'][-1])
  return stats
def define_keras_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_integer(
name="train_steps", default=None,
help="The number of steps to run for training. If it is larger than "
"# batches per epoch, then use # bathes per epoch. When this flag is "
"set, only one epoch is going to run for training.")
name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # bathes per epoch. When this flag is '
'set, only one epoch is going to run for training.')
def get_synth_input_fn(height, width, num_channels, num_classes,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_common
tf.logging.set_verbosity(tf.logging.ERROR)
class KerasCommonTests(tf.test.TestCase):
  """Tests for keras_common."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(KerasCommonTests, cls).setUpClass()

  def test_build_stats(self):
    """Dense (one-hot) accuracy history is normalized into the stats dict."""
    history = self._build_history(1.145, cat_accuracy=.99988)
    eval_output = self._build_eval_output(.56432111, 5.990)

    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.56432111, stats['accuracy_top_1'])
    self.assertEqual(5.990, stats['eval_loss'])

  def test_build_stats_sparse(self):
    """Sparse (integer-label) accuracy history is normalized the same way."""
    history = self._build_history(1.145, cat_accuracy_sparse=.99988)
    eval_output = self._build_eval_output(.928, 1.9844)

    stats = keras_common.build_stats(history, eval_output)

    self.assertEqual(1.145, stats['loss'])
    self.assertEqual(.99988, stats['training_accuracy_top_1'])
    self.assertEqual(.928, stats['accuracy_top_1'])
    self.assertEqual(1.9844, stats['eval_loss'])

  def _build_history(self, loss, cat_accuracy=None,
                     cat_accuracy_sparse=None):
    """Returns a mock Keras History whose .history maps metric names to lists."""
    metrics = {'loss': [np.float64(loss)]}
    if cat_accuracy:
      metrics['categorical_accuracy'] = [np.float64(cat_accuracy)]
    if cat_accuracy_sparse:
      metrics['sparse_categorical_accuracy'] = [np.float64(cat_accuracy_sparse)]
    mock_history = Mock()
    mock_history.history = metrics
    return mock_history

  def _build_eval_output(self, top_1, eval_loss):
    """Returns eval output shaped like model.evaluate(): [loss, top_1]."""
    return [np.float64(eval_loss), np.float64(top_1)]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment