# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from tensorflow.python.eager import profiler
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers


LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
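# The first tuple doubles as the warmup specification: learning_rate_schedule
# below ramps the learning rate linearly from 0 to the base rate over the
# first 5 epochs, then applies each multiplier from its start epoch onward.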


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  Scales the learning rate at the epoch boundaries given in LR_SCHEDULE by
  the corresponding multiplier.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate
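
# A quick worked example of the schedule above (illustrative, assuming
# keras_common.BASE_LEARNING_RATE = 0.1 and batch_size = 256, so
# initial_lr = 0.1):
#   epoch 2.5 (mid-warmup):    0.1 * 1.0 * 2.5 / 5 = 0.05
#   epoch 10  (past warmup):   0.1 * 1.0   = 0.1
#   epoch 45  (>= epoch 30):   0.1 * 0.1   = 0.01
#   epoch 85  (>= epoch 80):   0.1 * 0.001 = 0.0001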


def parse_record_keras(raw_record, is_training, dtype):
  """Adjust the shape of label."""
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  # Subtract one so that labels are in [0, 1000), and cast to float32 for
  # Keras model.
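  # For example, a raw 1-indexed label of 4 becomes the float32 tensor [3.].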
  label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
                  dtype=tf.float32)
  return image, label


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in TF 2.0 and should not be toggled.
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      sess = tf.compat.v1.Session(config=config)
      tf.compat.v1.keras.backend.set_session(sess)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)
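    # Under the 'infer_float32_vars' policy, layers keep their variables in
    # float32 and cast them to float16 for computation, preserving numerical
    # stability while gaining fp16 throughput.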

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())
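  # distribution_strategy is a flag string (e.g. 'mirrored' or 'off'); the
  # helper maps it to a tf.distribute strategy or None. The exact mapping
  # lives in distribution_utils.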

  strategy_scope = keras_common.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla
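  # With drop_remainder=True every batch has exactly `batch_size` examples,
  # so batch shapes stay static; the final partial batch of each epoch is
  # simply discarded.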

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype,
        drop_remainder=drop_remainder)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

    if flags_obj.enable_xla:
      if strategy and strategy.num_replicas_in_sync > 1:
        # TODO(b/129791381): Specify `per_replica_batch_size` value in
        # DistributionStrategy multi-replica case.
        per_replica_batch_size = None
      else:
        per_replica_batch_size = flags_obj.batch_size
    else:
      per_replica_batch_size = None

    if flags_obj.use_trivial_model:
      model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
    else:
      model = resnet_model.resnet50(
          num_classes=imagenet_main.NUM_CLASSES,
          dtype=dtype,
          batch_size=per_replica_batch_size)

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
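  # lr_callback re-evaluates learning_rate_schedule as training progresses and
  # writes the result into the optimizer, so the warmup/decay logic above
  # takes effect without recompiling the model.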

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  callbacks = [time_callback, lr_callback]
  if flags_obj.enable_tensorboard:
    callbacks.append(tensorboard_callback)
  if flags_obj.enable_e2e_xprof:
    profiler.start()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  if flags_obj.enable_e2e_xprof:
    results = profiler.stop()
    profiler.save(flags_obj.model_dir, results)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    return run(flags.FLAGS)


if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)