resnet_imagenet_main.py 11.8 KB
Newer Older
1
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Hongkun Yu's avatar
Hongkun Yu committed
21
22
import os

Hongkun Yu's avatar
Hongkun Yu committed
23
from absl import app
24
from absl import flags
25
from absl import logging
26
import tensorflow as tf
27

28
import tensorflow_model_optimization as tfmot
29
from official.modeling import performance
30
31
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
Toby Boyd's avatar
Toby Boyd committed
32
from official.utils.misc import keras_utils
33
from official.utils.misc import model_helpers
Allen Wang's avatar
Allen Wang committed
34
from official.vision.image_classification import test_utils
35
36
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
37
from official.vision.image_classification.resnet import resnet_model
38
39


Shining Sun's avatar
Shining Sun committed
40
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # Configure eager execution and XLA before anything touches the session.
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # Install the global mixed-precision policy; loss scale defaults to 128
  # for fp16 to avoid underflow in gradients.
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  # Pick a data format: channels_first is typically faster on GPU,
  # channels_last elsewhere. Honor an explicit flag when given.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  # get_strategy_scope returns a no-op scope when strategy is None, so the
  # `with strategy_scope:` below works either way.
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    # Synthetic data bypasses disk I/O entirely — used for benchmarking.
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channels-last.
  # We use the keras_applications MobileNet model, whose input format depends
  # on the Keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or just be channels-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Piecewise-constant LR schedule with linear warmup; boundaries and
  # multipliers come from common.LR_SCHEDULE (list of (multiplier, epoch)).
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  # Model and optimizer must be constructed under the strategy scope so
  # their variables are created on the right devices.
  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      # Scale the per-sample learning rate by the global batch size.
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    # Accuracy metrics are optional: skipping them can speed up benchmarks.
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    # The device context is entered manually here and exited after fit/eval
    # below, instead of using `with`, to keep the training call un-nested.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    # Remove pruning wrappers so the exported model is a plain Keras model.
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
Shining Sun's avatar
bug fix  
Shining Sun committed
281

Shining Sun's avatar
Shining Sun committed
282

Toby Boyd's avatar
Toby Boyd committed
283
def define_imagenet_keras_flags():
  """Register all command-line flags used by the ImageNet Keras trainer."""
  # Shared Keras training flags, plus the model-selection, optimizer-selection
  # and pretrained-checkpoint flags this script needs.
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  common.define_pruning_flags()
  flags_core.set_defaults()
  # Surface the key flags defined in `common` as key flags of this module.
  flags.adopt_module_key_flags(common)
Toby Boyd's avatar
Toby Boyd committed
291
292


293
def main(_):
  """Entry point: clean the model dir if requested, train, and log stats."""
  flags_obj = flags.FLAGS
  model_helpers.apply_clean(flags_obj)
  stats = run(flags_obj)
  logging.info('Run stats:\n%s', stats)
297
298
299


if __name__ == '__main__':
  # Show INFO-level logs (training progress, run stats) on the console.
  logging.set_verbosity(logging.INFO)
  # Flags must be defined before app.run parses argv.
  define_imagenet_keras_flags()
  app.run(main)