resnet_imagenet_main.py 13.9 KB
Newer Older
1
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Hongkun Yu's avatar
Hongkun Yu committed
21
22
import os

Hongkun Yu's avatar
Hongkun Yu committed
23
# Import libraries
Hongkun Yu's avatar
Hongkun Yu committed
24
from absl import app
25
from absl import flags
26
from absl import logging
27
import tensorflow as tf
28

29
from official.modeling import performance
30
31
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
Toby Boyd's avatar
Toby Boyd committed
32
from official.utils.misc import keras_utils
33
from official.utils.misc import model_helpers
Allen Wang's avatar
Allen Wang committed
34
from official.vision.image_classification import test_utils
35
36
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
37
from official.vision.image_classification.resnet import resnet_model
38
39


40
41
42
43
def _cluster_last_three_conv2d_layers(model):
  """Applies weight clustering to the last three Conv2D layers of `model`.

  The two earlier of the three targeted layers are clustered with 256
  centroids, the final one with 32 centroids, all using linear centroid
  initialization. Every other layer is returned unchanged.

  Args:
    model: A `tf.keras.Model` to selectively cluster.

  Returns:
    A cloned `tf.keras.Model` with clustering wrappers applied.
  """
  import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top

  conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ]
  target_layers = conv2d_layers[-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_initialization = tfmot.clustering.keras.CentroidInitialization

  def cluster_fn(layer):
    # Leave non-targeted layers untouched.
    if layer not in target_layers:
      return layer

    if layer in target_layers[:2]:
      # First two targeted conv layers: 256 centroids.
      clustered = cluster_weights(
          layer,
          number_of_clusters=256,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 256 clusters'.format(layer.name))
    else:
      # Final targeted conv layer: 32 centroids.
      clustered = cluster_weights(
          layer,
          number_of_clusters=32,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 32 clusters'.format(layer.name))
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)
Shining Sun's avatar
Shining Sun committed
69
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported
      (e.g. pruning/clustering with a non-float32 dtype, or an unknown
      pruning/clustering method).

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  # Reuse `dtype` instead of calling flags_core.get_tf_dtype a second time.
  performance.set_mixed_precision_policy(
      dtype,
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    # channels_first is generally faster on GPU; channels_last elsewhere.
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use keras_application mobilenet model whose input format depends on
  # the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should follow the keras backend image data
  # format or just channel-last format.
  # BUG FIX: the original expression
  #   flags_obj.model == 'mobilenet' or 'mobilenet_pretrained'
  # was always truthy (the `or` branch yields the non-empty string), so
  # ResNet runs also followed the keras backend data format. Use a proper
  # membership test instead.
  use_keras_image_data_format = (
      flags_obj.model in ('mobilenet', 'mobilenet_pretrained'))

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    # BUG FIX: `== 'mobilenet_default' or 'mobilenet_fine_tune'` was always
    # truthy, so ANY non-resnet50 optimizer value silently selected the
    # mobilenet optimizer. Use a membership test so unsupported values are
    # no longer silently accepted here.
    elif flags_obj.optimizer in ('mobilenet_default', 'mobilenet_fine_tune'):
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        # Fine-tuning uses a small fixed starting LR regardless of batch size.
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    # BUG FIX: `== 'mobilenet' or 'mobilenet_pretrained'` was always truthy,
    # so any unrecognized model name silently built a MobileNet. Membership
    # test restores the intended matching.
    elif flags_obj.model in ('mobilenet', 'mobilenet_pretrained'):
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      if flags_obj.model == 'mobilenet_pretrained':
        # ImageNet-pretrained weights require the original 1000-class head.
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      # Function-scope import: this binding is also used by strip_pruning
      # after training, which only runs when this branch was taken.
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      # Function-scope import: this binding is also used by strip_clustering
      # after training, which only runs when this branch was taken.
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if (dtype != tf.float32 or
          flags_obj.fp16_implementation == 'graph_rewrite'):
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      model = _cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # Strip the optimization wrappers so the exported model is plain Keras.
  # `tfmot` is bound by the function-scope imports above, which are
  # guaranteed to have run whenever these flags are set (unsupported
  # methods raise NotImplementedError earlier).
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats

Shining Sun's avatar
Shining Sun committed
335

Toby Boyd's avatar
Toby Boyd committed
336
def define_imagenet_keras_flags():
  """Registers the command-line flags used by the ImageNet Keras runner."""
  # Core Keras flags, plus model/optimizer selection and pretrained weights.
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  # Optional model-optimization flag groups.
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


347
def main(_):
  """Cleans the model directory if requested, trains, and logs run stats."""
  flag_values = flags.FLAGS
  model_helpers.apply_clean(flag_values)
  run_stats = run(flag_values)
  logging.info('Run stats:\n%s', run_stats)
353


if __name__ == '__main__':
  # Surface INFO-level logs (training progress, run stats) on the console.
  logging.set_verbosity(logging.INFO)
  # Flags must be registered before app.run parses argv.
  define_imagenet_keras_flags()
  app.run(main)