# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import tensorflow_model_optimization as tfmot
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  # Default to channels_first on GPU (faster with cuDNN), channels_last
  # elsewhere, unless the user overrode it via --data_format.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use keras_application mobilenet model which input format depends on
  # the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether image preprocessor
  # output format should be same as the keras backend image data format or just
  # channel-last format.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # Export a pruned, float32 SavedModel when requested.
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input designature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats

def define_imagenet_keras_flags():
  """Define the flags used by the ImageNet Keras training binary."""
  common.define_keras_flags(
      model=True,
      optimizer=True,
      pretrained_filepath=True)
  common.define_pruning_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  """Entry point: optionally clean model_dir, run training, log stats."""
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)