# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers


LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  Scales the learning rate at the epoch boundaries listed in LR_SCHEDULE by the
  corresponding multipliers.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
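  # After warmup, step the learning rate down: keep the multiplier of the
  # latest LR_SCHEDULE boundary whose start epoch has already been reached.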
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)
  if flags_obj.data_delay_prefetch:
    keras_common.data_delay_prefetch()
  keras_common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
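  # For fp16, variables are kept in float32 ('infer_float32_vars') while math
  # runs in float16; loss scaling is attached to the optimizer further below.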
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

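  # Default the image data format: channels_first when TF is built with CUDA
  # (typically faster on GPU), channels_last otherwise.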
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

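  # With synthetic data, the input_fn produces in-memory tensors of ImageNet
  # shape instead of reading real records from data_dir.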
  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=imagenet_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=dtype,
        drop_remainder=drop_remainder)

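  # Learning rate: a constant 0.1 by default, or a warmup + piecewise-constant
  # decay schedule built from LR_SCHEDULE when flags_obj.use_tensor_lr is set.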
  lr_schedule = 0.1
  if flags_obj.use_tensor_lr:
    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in LR_SCHEDULE),
        compute_lr_on_cpu=True)

  with strategy_scope:
    optimizer = keras_common.get_optimizer(lr_schedule)
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(
          imagenet_preprocessing.NUM_CLASSES, dtype)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

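  # Standard Keras callbacks from keras_common, parameterized by the per-batch
  # learning_rate_schedule above and the number of training images.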
  callbacks = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

  train_steps = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

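  # Train; validation (if enabled) runs every `epochs_between_evals` epochs.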
  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

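  # Summarize training history, final eval output, and callback data.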
  stats = keras_common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  keras_common.define_keras_flags()
  flags_core.set_defaults(train_epochs=90)
  flags.adopt_module_key_flags(keras_common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  absl_app.run(main)