keras_common.py 5.28 KB
Newer Older
1
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Common util functions and classes used by both keras cifar and imagenet."""
16
17
18
19
20
21
22
23
24
25
26
27
28
29

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl import flags
import numpy as np
import tensorflow as tf  # pylint: disable=g-bad-import-order

from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2


Shining Sun's avatar
Shining Sun committed
30
FLAGS = flags.FLAGS

# Base learning rate before linear scaling by batch size.
# This matches Jing's version.
BASE_LEARNING_RATE = 0.1

33
34
35
36
class TimeHistory(tf.keras.callbacks.Callback):
  """Keras callback that measures training throughput.

  Wall-clock time is recorded for every `log_batch_size` consecutive
  batches and the corresponding images/second rate is logged.
  """

  def __init__(self, batch_size):
    """Callback for logging performance (# image/second).

    Args:
      batch_size: Total batch size.
    """
    super(TimeHistory, self).__init__()
    self._batch_size = batch_size
    # Number of batches between two throughput measurements.
    self.log_batch_size = 100

  def on_train_begin(self, logs=None):
    self.batch_times_secs = []
    self.record_batch = True

  def on_batch_begin(self, batch, logs=None):
    # Start a new timing window only when the previous one was closed.
    if self.record_batch:
      self.batch_time_start = time.time()
      self.record_batch = False

  def on_batch_end(self, batch, logs=None):
    if batch % self.log_batch_size != 0:
      return
    elapsed_secs = time.time() - self.batch_time_start
    examples_per_second = (
        self._batch_size * self.log_batch_size) / elapsed_secs
    self.batch_times_secs.append(elapsed_secs)
    self.record_batch = True
    # TODO(anjalisridhar): add timestamp as well.
    if batch != 0:
      tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
                      "'images_per_second': %f}" %
                      (batch, elapsed_secs, examples_per_second))

class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only support Keras optimizers, not TF optimizers.

  Args:
      schedule: a function that takes an epoch index, a batch index, the
          number of batches per epoch and the batch size as input (epoch
          and batch are integers, indexed from 0) and returns a new
          learning rate as output (float).
      batch_size: total batch size used per training step.
      num_images: total number of training images, used to derive the
          number of batches per epoch.
  """

  def __init__(self, schedule, batch_size, num_images):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    # May be fractional when num_images is not a multiple of batch_size.
    self.batches_per_epoch = num_images / batch_size
    self.batch_size = batch_size
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    """Validate the optimizer and advance the internal epoch counter."""
    if not hasattr(self.model.optimizer, 'learning_rate'):
      raise ValueError('Optimizer must have a "learning_rate" attribute.')
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    """Recompute the learning rate and apply it when it has changed."""
    lr = self.schedule(self.epochs,
                       batch,
                       self.batches_per_epoch,
                       self.batch_size)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function should be float.')
    if lr != self.prev_lr:
      self.model.optimizer.learning_rate = lr  # lr should be a float here
      self.prev_lr = lr
      tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
                       'change learning rate to %s.', self.epochs, batch, lr)

Shining Sun's avatar
Shining Sun committed
103
104
105
106
107
def get_optimizer():
  """Select the training optimizer based on command-line flags.

  Returns:
    A tf.train.MomentumOptimizer when --use_tf_momentum_optimizer is set
    (with the base learning rate linearly scaled by batch size),
    otherwise a Keras v2 SGD optimizer with momentum.
  """
  if FLAGS.use_tf_momentum_optimizer:
    scaled_lr = BASE_LEARNING_RATE * FLAGS.batch_size / 256
    return tf.train.MomentumOptimizer(
        learning_rate=scaled_lr, momentum=0.9)
  return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
112
113


114
def get_callbacks(learning_rate_schedule_fn, num_images):
  """Build the standard set of Keras callbacks used for training.

  Args:
    learning_rate_schedule_fn: function producing the per-batch learning
        rate; forwarded to LearningRateBatchScheduler.
    num_images: total number of training images.

  Returns:
    A tuple (time_callback, tensorboard_callback, lr_callback).
  """
  time_callback = TimeHistory(FLAGS.batch_size)

  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.model_dir)

  lr_callback = LearningRateBatchScheduler(
      learning_rate_schedule_fn,
      batch_size=FLAGS.batch_size,
      num_images=num_images)

  return time_callback, tensorboard_callback, lr_callback

Shining Sun's avatar
bug fix  
Shining Sun committed
127
def analyze_fit_and_eval_result(history, eval_output):
  """Summarize training history and evaluation output into a stats dict.

  Args:
    history: Keras History object returned by model.fit(); its `history`
        dict must contain 'loss' and 'categorical_accuracy' series.
    eval_output: sequence of [eval_loss, top_1_accuracy] as returned by
        model.evaluate().

  Returns:
    Dict with keys 'accuracy_top_1', 'eval_loss', 'training_loss' and
    'training_accuracy_top_1'.
  """
  stats = {
      'accuracy_top_1': eval_output[1],
      'eval_loss': eval_output[0],
      'training_loss': history.history['loss'][-1],
      'training_accuracy_top_1': history.history['categorical_accuracy'][-1],
  }

  print('Test loss:{}'.format(stats['eval_loss']))
  print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
  print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))

  return stats
Shining Sun's avatar
Shining Sun committed
139
140
141

def define_keras_flags():
  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
142
  flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
Shining Sun's avatar
Shining Sun committed
143
144
145
  flags.DEFINE_integer(
      name="train_steps", default=None,
      help="The number of steps to run for training")
146