# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import libraries
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from official.benchmark.models import cifar_preprocessing
from official.benchmark.models import resnet_cifar_model
from official.benchmark.models import synthetic_util
from official.common import distribute_utils
from official.legacy.image_classification.resnet import common
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils


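# This matches the widely used ResNet-on-CIFAR-10 decay recipe: train at the
# base rate, then cut the rate by 10x at epochs 91, 136, and 182.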
LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (0.1, 91), (0.01, 136), (0.001, 182)
]


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule and LR decay.

  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
  corresponding multiplier.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  del current_batch, batches_per_epoch  # not used
  initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only supports Keras optimizers, not TF optimizers.

  Attributes:
      schedule: a function that takes an epoch index and a batch index as input
          (both integer, indexed from 0) and returns a new learning rate as
          output (float).
  """

  def __init__(self, schedule, batch_size, steps_per_epoch):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.steps_per_epoch = steps_per_epoch
    self.batch_size = batch_size
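    # Start the epoch counter at -1 so the first on_epoch_begin call brings
    # it to 0, and cache the previous learning rate so the optimizer is only
    # updated when the rate actually changes.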
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'learning_rate'):
      raise ValueError('Optimizer must have a "learning_rate" attribute.')
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    """Executes before step begins."""
    lr = self.schedule(self.epochs,
                       batch,
                       self.steps_per_epoch,
                       self.batch_size)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError(
          'The output of the "schedule" function should be a float.')
    if lr != self.prev_lr:
      self.model.optimizer.learning_rate = lr  # lr should be a float here
      self.prev_lr = lr
      logging.debug(
          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
          'change learning rate to %s.', self.epochs, batch, lr)


def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # get_tf_dtype returns a tf.DType, so compare against tf.float16 rather
  # than the 'fp16' flag string (which would never match).
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

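  # Unless overridden by the --data_format flag, prefer channels_first when a
  # GPU is available (typically faster with cuDNN), channels_last otherwise.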
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    synthetic_util.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    synthetic_util.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Set drop_remainder to avoid the partial-batch logic in the
      # normalization layer, which triggers tf.where and leads to an extra
      # memory copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=cifar_preprocessing.parse_record)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
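  # lr_schedule is either a plain float (0.1), overridden per batch by the
  # LearningRateBatchScheduler callback below, or, with --use_tensor_lr, a
  # PiecewiseConstantDecay built from LR_SCHEDULE. PiecewiseConstantDecay
  # expects len(values) == len(boundaries) + 1, hence the prepended base rate.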
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
        values=[initial_learning_rate] +
        list(p[0] * initial_learning_rate for p in LR_SCHEDULE))

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks()

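  # Without a tensor LR schedule, fall back to adjusting the optimizer's
  # learning rate manually at the start of every batch.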
  if not flags_obj.use_tensor_lr:
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If training for multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_cifar_flags():
  common.define_keras_flags()

  flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                          model_dir='/tmp/cifar10_model',
                          epochs_between_evals=10,
                          batch_size=128)


def main(_):
  return run(flags.FLAGS)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_cifar_flags()
  app.run(main)