# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import tensorflow_model_optimization as tfmot
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
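
# Example invocation (hypothetical paths and values; the actual flag set is
# defined by define_imagenet_keras_flags() below and may differ across
# releases):
#
#   python resnet_imagenet_main.py \
#       --data_dir=/path/to/imagenet \
#       --model_dir=/tmp/resnet50 \
#       --num_gpus=8 \
#       --batch_size=1024 \
#       --train_epochs=90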


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
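  # Set the global mixed precision policy from the --dtype flag; with fp16 a
  # default loss scale of 128 is applied to guard against gradient underflow.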
  performance.set_mixed_precision_policy(
      dtype,
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

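  # Prefer channels_first (NCHW) when a GPU is available, which is generally
  # faster under cuDNN; otherwise fall back to channels_last.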
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

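  # Build the tf.distribute strategy requested by --distribution_strategy
  # (e.g. 'mirrored', 'multi_worker_mirrored', 'tpu', or 'off').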
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last.
  # The keras_applications MobileNet model, by contrast, takes its input
  # format from the Keras backend image data format. The
  # use_keras_image_data_format flag indicates whether the image preprocessor
  # output format should match the Keras backend image data format or stay
  # channels-last.
  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

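  # common.LR_SCHEDULE is a list of (multiplier, epoch) pairs: the first entry
  # provides the warmup length, and the remaining entries supply the
  # boundaries and multipliers of the piecewise-constant decay.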
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
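      # PolynomialDecay ramps the sparsity of the prunable weights from
      # pruning_initial_sparsity at begin_step to pruning_final_sparsity at
      # end_step, updating the pruning masks every `frequency` steps.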
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

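    # Labels from imagenet_preprocessing are integer class ids, hence the
    # sparse categorical loss and accuracy metric.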
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

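  # Assemble the training callbacks (step timing plus, when enabled by flags,
  # checkpointing/export and pruning-mask update callbacks).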
  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If training for more than one epoch, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(
      model=True,
      optimizer=True,
      pretrained_filepath=True)
  common.define_pruning_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)