# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import numpy as np
import tensorflow as tf

from official.benchmark.models import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common


LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (0.1, 91), (0.01, 136), (0.001, 182)
]
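# A rough worked example of the schedule (assuming common.BASE_LEARNING_RATE is
# 0.1 and the default batch size of 128, so the initial rate is also 0.1):
# epochs 0-90 train at 0.1, epochs 91-135 at 0.01, epochs 136-181 at 0.001, and
# epochs 182+ at 0.0001, as applied by learning_rate_schedule below.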


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule and LR decay.

  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
  provided scaling factor.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  del current_batch, batches_per_epoch  # not used
  initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate
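# Illustrative only: with the same assumption that common.BASE_LEARNING_RATE is
# 0.1, learning_rate_schedule(current_epoch=100, current_batch=0,
# batches_per_epoch=390, batch_size=1024) scales the base rate linearly to
# 0.1 * 1024 / 128 = 0.8, then applies the 0.1 multiplier for epochs >= 91 and
# returns 0.08.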


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only supports Keras optimizers, not TF optimizers.

  Attributes:
      schedule: a function that takes an epoch index and a batch index as input
          (both integer, indexed from 0) and returns a new learning rate as
          output (float).
  """

  def __init__(self, schedule, batch_size, steps_per_epoch):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.steps_per_epoch = steps_per_epoch
    self.batch_size = batch_size
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'learning_rate'):
      raise ValueError('Optimizer must have a "learning_rate" attribute.')
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    """Executes before step begins."""
    lr = self.schedule(self.epochs,
                       batch,
                       self.steps_per_epoch,
                       self.batch_size)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError(
          'The output of the "schedule" function should be a float.')
    if lr != self.prev_lr:
      self.model.optimizer.learning_rate = lr  # lr should be a float here
      self.prev_lr = lr
      tf.compat.v1.logging.debug(
          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
          'change learning rate to %s.', self.epochs, batch, lr)
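# Minimal usage sketch (hypothetical names and values; run() below wires the
# real flags in): pair the callback with the schedule function defined above
# and hand it to Keras, e.g.
#
#   lr_callback = LearningRateBatchScheduler(
#       schedule=learning_rate_schedule, batch_size=128, steps_per_epoch=390)
#   model.fit(train_dataset, epochs=182, callbacks=[lr_callback])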


def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Set drop_remainder to avoid the partial-batch logic in the
      # normalization layer, which triggers tf.where and leads to an extra
      # memory copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
        values=[initial_learning_rate] +
        list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
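    # For orientation (not asserted by this code): with CIFAR-10's 50,000
    # training images and the default batch size of 128, steps_per_epoch is
    # 390, so the boundaries above fall at steps 35490, 53040 and 70980, and
    # the values are [0.1, 0.01, 0.001, 0.0001], assuming
    # common.BASE_LEARNING_RATE is 0.1.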

  with strategy_scope:
    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch)

  if not flags_obj.use_tensor_lr:
    lr_callback = LearningRateBatchScheduler(
        schedule=learning_rate_schedule,
        batch_size=flags_obj.batch_size,
        steps_per_epoch=steps_per_epoch)
    callbacks.append(lr_callback)

  # If training for more than one epoch, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_cifar_flags():
  common.define_keras_flags(dynamic_loss_scale=False)

  flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                          model_dir='/tmp/cifar10_model',
                          epochs_between_evals=10,
                          batch_size=128)


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    return run(flags.FLAGS)


if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  define_cifar_flags()
  absl_app.run(main)
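# Example invocation (a sketch; the full flag set comes from
# common.define_keras_flags and flags_core, so consult --help for the
# authoritative list):
#
#   python resnet_cifar_main.py \
#       --data_dir=/tmp/cifar10_data/cifar-10-batches-bin \
#       --model_dir=/tmp/cifar10_model \
#       --num_gpus=1 --batch_size=128 --train_epochs=182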