# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def _cluster_last_three_conv2d_layers(model):
  """Applies weight clustering to the last three Conv2D layers of `model`.

  The first two of the selected layers are clustered with 256 centroids and
  the final one with 32, all using linearly spaced centroid initialization.
  Every other layer is returned unchanged.

  Args:
    model: A `tf.keras.Model` whose trailing Conv2D layers should be
      clustered.

  Returns:
    A clone of `model` (via `tf.keras.models.clone_model`) in which the
    selected Conv2D layers are wrapped for weight clustering.
  """
  import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_initialization = tfmot.clustering.keras.CentroidInitialization

  def cluster_fn(layer):
    """Clone function: wraps only the selected layers for clustering."""
    if layer not in last_three_conv2d_layers:
      return layer
    # The two earlier layers keep more centroids (256); the final layer is
    # compressed harder (32 centroids).
    if layer in last_three_conv2d_layers[:2]:
      number_of_clusters = 256
    else:
      number_of_clusters = 32
    clustered = cluster_weights(
        layer,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=centroid_initialization.LINEAR)
    # Use absl logging for consistency with the rest of this script
    # (the original used a bare print()).
    logging.info('Clustered %s with %d clusters', layer.name,
                 number_of_clusters)
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj))

  # Pick the data format: channels_first on GPU for better cuDNN performance,
  # channels_last otherwise, unless the user overrode it via flag.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channels-last.
  # The keras_applications MobileNet model's input format depends on the
  # Keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format, or just be channels-last.
  use_keras_image_data_format = \
    (flags_obj.model == 'mobilenet' or flags_obj.model == 'mobilenet_pretrained')

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Piecewise-constant LR schedule with linear warmup, scaled by batch size.
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default' or flags_obj.optimizer == 'mobilenet_fine_tune':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      # Fine-tuning uses a small fixed initial LR regardless of batch size.
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    # Wraps the optimizer for fp16 loss scaling / graph rewrite if requested.
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=flags_core.get_tf_dtype(flags_obj) == tf.float16,
        use_graph_rewrite=flags_obj.fp16_implementation == 'graph_rewrite',
        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128),)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet' or flags_obj.model == 'mobilenet_pretrained':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      if flags_obj.model == 'mobilenet_pretrained':
        # ImageNet-pretrained weights require the original 1000-class head.
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32 or \
        flags_obj.fp16_implementation == 'graph_rewrite':
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      model = _cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # NOTE: `tfmot` below is the module bound by the function-local imports in
  # the pruning/clustering branches above; when either flag is truthy, the
  # corresponding branch has already imported it (or raised).
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  """Registers every command-line flag used by this training script."""
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  """Optionally cleans the model dir, runs training/eval, and logs stats."""
  model_helpers.apply_clean(flags.FLAGS)
  run_stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', run_stats)


if __name__ == '__main__':
  # Script entry point: surface INFO-level progress, register flags, then
  # hand control to absl so flags are parsed before main() runs.
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)