# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import imagenet_main
26
from official.resnet.keras import keras_common
Shining Sun's avatar
Shining Sun committed
27
from official.resnet.keras import resnet_model
28
29
30
31
32
33
34
35
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils


# Learning rate schedule as (multiplier, epoch to start) tuples. The first
# entry also supplies the multiplier and end epoch of the warmup phase used by
# learning_rate_schedule().
LR_SCHEDULE = [
    (1.0, 5),
    (0.1, 30),
    (0.01, 60),
    (0.001, 80),
]
Shining Sun's avatar
Shining Sun committed
36

37

Toby Boyd's avatar
Toby Boyd committed
38
39
40
41
def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Computes the learning rate for a given position in training.

  The base learning rate is scaled linearly with the batch size (relative to
  a batch size of 256). Before the start epoch of the first LR_SCHEDULE entry
  the rate grows linearly with the fractional epoch, starting from zero.
  After that, the rate is the scaled base rate times the multiplier of the
  last LR_SCHEDULE entry whose start epoch has been reached; the scan stops
  at the first entry whose start epoch has not been reached yet.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  scaled_base_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
  # Position in training expressed as a (possibly fractional) epoch number.
  fractional_epoch = current_epoch + float(current_batch) / batches_per_epoch

  warmup_multiplier, warmup_boundary = LR_SCHEDULE[0]
  if fractional_epoch < warmup_boundary:
    # Warmup: learning rate increases linearly per step.
    return (scaled_base_lr * warmup_multiplier * fractional_epoch /
            warmup_boundary)

  for multiplier, boundary in LR_SCHEDULE:
    if fractional_epoch < boundary:
      break
    learning_rate = scaled_base_lr * multiplier
  return learning_rate


def parse_record_keras(raw_record, is_training, dtype):
  """Parses a record and adjusts the shape and range of its label.

  Parsing is delegated to imagenet_main.parse_record; only the label is
  post-processed here.

  Args:
    raw_record: passed through unchanged to imagenet_main.parse_record.
    is_training: passed through unchanged to imagenet_main.parse_record.
    dtype: passed through unchanged to imagenet_main.parse_record.

  Returns:
    Tuple of (image, label), where label has shape [1] and dtype float32.
  """
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  label = tf.reshape(label, shape=[1])
  # Subtract one so that labels are in [0, 1000).
  label = tf.cast(label, dtype=tf.int32) - 1
  # Cast to float32 for Keras model.
  label = tf.cast(label, dtype=tf.float32)
  return image, label


Shining Sun's avatar
Shining Sun committed
81
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Builds the training and evaluation datasets (synthetic or real, depending
  on flags), compiles a ResNet50 Keras model with a distribution strategy,
  trains it with model.fit and, unless evaluation is skipped, runs a final
  model.evaluate pass.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): get_tf_dtype presumably returns a tf.DType (confirm against
  # flags_core). A tf.DType never compares equal to the string 'fp16', so the
  # original string-only check could not fire. Compare against tf.float16 as
  # well; the string comparison is kept for backward compatibility.
  if dtype == tf.float16 or dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')

  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=per_device_batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=per_device_batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  optimizer = keras_common.get_optimizer()
  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)

  model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=['sparse_categorical_accuracy'],
                distribute=strategy)

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  # Number of full global batches in one pass over the training set.
  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # An explicit train_steps flag caps the steps and limits training to a
  # single epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  # With skip_eval set, no validation is run during fit() and the final
  # evaluate() call below is skipped as well.
  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    num_eval_steps = None
    validation_data = None

  model.fit(train_input_dataset,
            epochs=train_epochs,
            steps_per_epoch=train_steps,
            callbacks=[
                time_callback,
                lr_callback,
                tensorboard_callback
                ],
            validation_steps=num_eval_steps,
            validation_data=validation_data,
            verbose=1)

  if not flags_obj.skip_eval:
    model.evaluate(eval_input_dataset,
                   steps=num_eval_steps,
                   verbose=1)
Shining Sun's avatar
bug fix  
Shining Sun committed
169

Shining Sun's avatar
Shining Sun committed
170

171
172
def main(_):
  """Runs the training/eval loop inside the benchmark logging context.

  Args:
    _: Unused positional argument.
  """
  flags_obj = flags.FLAGS
  with logger.benchmark_context(flags_obj):
    run(flags_obj)
174
175
176


if __name__ == '__main__':
  # Script entry point: set log verbosity, register the ImageNet and Keras
  # flag definitions, then hand control to absl, which invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)