# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Standard library.
import os

# Third-party packages.
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import tensorflow_model_optimization as tfmot
# Project-local modules (TF Model Garden "official" package). Import order is
# preserved: several of these modules register flags at import time.
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # Reuse `dtype` instead of calling flags_core.get_tf_dtype() a second time.
  performance.set_mixed_precision_policy(
      dtype,
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channels-last.
  # The keras_applications MobileNet model's input format depends on the Keras
  # backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or just be channels-last.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  """Defines the command-line flags used by ImageNet Keras training.

  Registers the shared Keras training flags (with model, optimizer, and
  pretrained-checkpoint selection enabled), the pruning flags, the flag
  defaults, and adopts the key flags from the `common` module.
  """
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  common.define_pruning_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  """Cleans the model dir if requested, runs training/eval, logs the stats."""
  flags_obj = flags.FLAGS
  model_helpers.apply_clean(flags_obj)
  stats = run(flags_obj)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  # Script entry point: enable INFO logging, register flags, then let absl
  # parse argv and dispatch to main().
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)