# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers


# Learning-rate schedule as (multiplier, start_epoch) pairs. The first entry
# doubles as the warmup spec: `learning_rate_schedule` ramps the LR linearly
# per step toward multiplier 1.0 until epoch 5, then applies each entry's
# multiplier from its start_epoch onward.
LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The base learning rate is scaled linearly with the batch size (relative to
  a reference batch of 256). Within the warmup window defined by the first
  LR_SCHEDULE entry, the rate grows linearly per step; afterwards it is the
  scaled base rate times the multiplier of the most recent LR_SCHEDULE
  boundary that has been crossed.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  # Linear scaling rule: LR grows proportionally with the global batch size.
  base_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256

  # Fractional epoch so the warmup ramp advances every batch, not per epoch.
  fractional_epoch = current_epoch + float(current_batch) / batches_per_epoch

  warmup_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if fractional_epoch < warmup_end_epoch:
    # Still warming up: ramp linearly toward warmup_multiplier * base_lr.
    return base_lr * warmup_multiplier * fractional_epoch / warmup_end_epoch

  # Past warmup: keep the multiplier of the last boundary already reached.
  lr = base_lr * warmup_multiplier
  for multiplier, start_epoch in LR_SCHEDULE:
    if fractional_epoch < start_epoch:
      break
    lr = base_lr * multiplier
  return lr


def parse_record_keras(raw_record, is_training, dtype):
  """Adjust the shape of label."""
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  # The parsed labels are 1-based; shift them into [0, 1000). The Keras model
  # expects a rank-1 float32 label tensor, so reshape then cast.
  shifted_label = tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1
  label = tf.cast(shifted_label, dtype=tf.float32)
  return image, label


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # Configure the TF session (eager mode, XLA, grappler layout optimizer)
  # before any graph/dataset construction happens.
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla,
      enable_grappler_layout_optimizer=
      flags_obj.enable_grappler_layout_optimizer)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_delay_prefetch:
    keras_common.data_delay_prefetch()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    # NOTE(review): 'infer_float32_vars' presumably keeps variables in float32
    # while computing in float16 — confirm against the experimental
    # tf.keras.mixed_precision docs for this TF version.
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  # Pick a data format if the user did not: channels_first on CUDA builds
  # (presumably for cuDNN performance — TODO confirm), channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster(),
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  # Null context manager when no strategy is in use, so the `with` below works
  # either way.
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Constant LR of 0.1 by default; replaced by a warmup + piecewise-constant
  # decay schedule built from LR_SCHEDULE when --use_tensor_lr is set.
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_main.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in LR_SCHEDULE),
        compute_lr_on_cpu=True)

  # Model and optimizer must be created under the strategy scope so their
  # variables are distributed correctly.
  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
                  run_eagerly=flags_obj.run_eagerly,
                  cloning=flags_obj.clone_model_in_keras_dist_strat)

  # NOTE(review): get_callbacks wires the Python-level learning_rate_schedule
  # defined above into a per-batch LR callback — exact callback set is defined
  # in keras_common.
  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # --train_steps caps the run to a single (possibly shortened) epoch,
  # typically for benchmarking.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    # Final standalone evaluation over the full validation split.
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats
def define_imagenet_keras_flags():
  """Registers all command-line flags used by this script."""
  # ResNet/ImageNet flags first (with dynamic loss scaling and the XLA flag
  # enabled), then the Keras-specific flags layered on top.
  imagenet_main.define_imagenet_flags(dynamic_loss_scale=True, enable_xla=True)
  keras_common.define_keras_flags()


256
def main(_):
  """absl entry point: applies clean-up flag logic, then trains/evaluates."""
  # NOTE(review): apply_clean presumably removes a stale model directory when
  # the corresponding flag is set — see official.utils.misc.model_helpers.
  model_helpers.apply_clean(flags.FLAGS)
  # benchmark_context handles benchmark logging setup/teardown around the run.
  with logger.benchmark_context(flags.FLAGS):
    return run(flags.FLAGS)


if __name__ == '__main__':
  # Surface INFO-level TF logs (training progress) on the console.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  # Flags must be defined before absl parses argv in absl_app.run.
  define_imagenet_keras_flags()
  absl_app.run(main)