# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras import cluster

from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def selective_layers_to_cluster(model):
  """Returns names of the last three non-depthwise Conv2D layers."""
  conv2d_layer_names = [
      layer.name
      for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D) and
      not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
  ]
  return conv2d_layer_names[-3:]
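

def _example_selective_layers_to_cluster():
  """Illustrative sketch only; not called by the training pipeline.

  Demonstrates the selection rule above on a toy model. DepthwiseConv2D
  subclasses Conv2D in tf.keras, which is why it must be excluded
  explicitly.
  """
  toy = tf.keras.Sequential([
      tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3), name='c1'),
      tf.keras.layers.DepthwiseConv2D(3, name='dw'),
      tf.keras.layers.Conv2D(16, 3, name='c2'),
      tf.keras.layers.Conv2D(32, 3, name='c3'),
      tf.keras.layers.Conv2D(64, 3, name='c4'),
  ])
  # The depthwise layer is skipped; the last three plain Conv2D layers are
  # returned.
  assert selective_layers_to_cluster(toy) == ['c2', 'c3', 'c4']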


def selective_clustering_clone_wrapper(clustering_params1, clustering_params2,
                                       model):
  """Returns a clone_function that clusters the selected Conv2D layers.

  The last selected layer is clustered with `clustering_params2`; the other
  selected layers use `clustering_params1`.
  """

  def apply_clustering_to_conv2d_but_depthwise(layer):
    layers_list = selective_layers_to_cluster(model)
    if layer.name in layers_list:
      if layer.name != layers_list[-1]:
        logging.info('Wrapped layer %s with %s clusters.', layer.name,
                     clustering_params1['number_of_clusters'])
        return cluster.cluster_weights(layer, **clustering_params1)
      else:
        logging.info('Wrapped layer %s with %s clusters.', layer.name,
                     clustering_params2['number_of_clusters'])
        return cluster.cluster_weights(layer, **clustering_params2)
    return layer

  return apply_clustering_to_conv2d_but_depthwise
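

# `tf.keras.models.clone_model` calls the clone_function once per layer:
# returning the layer keeps it unchanged, while returning a wrapped layer
# swaps it in. A minimal sketch of the same pattern (illustrative values,
# not used by this pipeline):
#
#   def clone_fn(layer):
#     if isinstance(layer, tf.keras.layers.Dense):
#       return cluster.cluster_weights(
#           layer, number_of_clusters=8, cluster_centroids_init='linear')
#     return layer
#
#   clustered = tf.keras.models.clone_model(model, clone_function=clone_fn)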


def cluster_model_selectively(model, clustering_params1, clustering_params2):
  """Clones `model`, clustering only the selected Conv2D layers."""
  return tf.keras.models.clone_model(
      model,
      clone_function=selective_clustering_clone_wrapper(clustering_params1,
                                                        clustering_params2,
                                                        model),
  )


def get_selectively_clustered_model(model, clustering_params1,
                                    clustering_params2):
  return cluster_model_selectively(model, clustering_params1,
                                   clustering_params2)
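

# Typical usage of the helpers above (parameter values are illustrative,
# not tuned):
#
#   clustered = get_selectively_clustered_model(
#       model,
#       clustering_params1={'number_of_clusters': 16,
#                           'cluster_centroids_init': 'linear'},
#       clustering_params2={'number_of_clusters': 32,
#                           'cluster_centroids_init': 'linear'})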


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
102
    NotImplementedError: If some features are not currently supported.
103
104
105

  Returns:
    Dictionary of training and eval stats.
106
  """
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      dtype, flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last,
  # whereas the Keras Applications MobileNet model takes its input format
  # from the Keras backend image data format.
  # The use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or simply be channels-last.
  use_keras_image_data_format = flags_obj.model in ('mobilenet',
                                                    'mobilenet_pretrained')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
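  # PiecewiseConstantDecayWithWarmup (see common.py) linearly warms the
  # learning rate up over `warmup_epochs`, then holds it piecewise constant,
  # switching multipliers at the listed epoch boundaries; the base rate is
  # scaled linearly with the batch size. (This summary assumes the reference
  # implementation of the helper in common.py.)
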
  steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)
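  # For example, with the full ImageNet train split (1,281,167 images) and a
  # batch size of 256, this gives 1281167 // 256 = 5004 steps per epoch.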

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
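      # Staircase exponential decay, i.e.
      #   lr(step) = initial_learning_rate
      #              * lr_decay_factor ** floor(step / decay_steps).
      # For example, with lr_decay_factor=0.94 (an illustrative value), the
      # learning rate is multiplied by 0.94 once every num_epochs_per_decay
      # epochs.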
    elif flags_obj.optimizer == 'mobilenet_fine_tune':
      optimizer = tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove the layers attribute once the minimum TF
      # version supports 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    elif flags_obj.model == 'mobilenet_pretrained':
      shape = (224, 224, 3)
      model = tf.keras.applications.mobilenet.MobileNet(
          input_shape=shape,
          alpha=1.0,
          depth_multiplier=1,
          dropout=1e-7,
          include_top=True,
          weights='imagenet',
          input_tensor=tf.keras.layers.Input(shape),
          pooling=None,
          classes=1000,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
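      # A sketch of the schedule's shape (this assumes tfmot's default
      # power=3 for PolynomialDecay; consult the tfmot docs):
      #   sparsity(t) = final + (initial - final)
      #                 * (1 - (t - begin_step) / (end_step - begin_step))**3
      # re-evaluated every `frequency` steps for begin_step <= t <= end_step.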
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError(
          'Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      clustering_params1 = {
          'number_of_clusters': flags_obj.number_of_clusters,
          'cluster_centroids_init': 'linear'
      }
      clustering_params2 = {
          'number_of_clusters': 32,
          'cluster_centroids_init': 'linear'
      }
      model = get_selectively_clustered_model(model, clustering_params1,
                                              clustering_params2)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
                  run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If running more than one epoch, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    # Strip the pruning wrappers so the exported model contains plain layers
    # with sparse weights.
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    # Strip the clustering wrappers so the exported model contains plain
    # layers with clustered weights.
    model = cluster.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)
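      # The exported SavedModel can later be restored for inference with the
      # stock Keras API, e.g. (illustrative):
      #   reloaded = tf.keras.models.load_model(export_path)
      #   predictions = reloaded.predict(eval_input_dataset)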

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  if flags_obj.clustering_method:
    if flags_obj.save_files_to:
      keras_file = os.path.join(flags_obj.save_files_to, 'clustered.h5')
    else:
      keras_file = './clustered.h5'
    logging.info('Saving clustered and stripped model to: %s', keras_file)
    tf.keras.models.save_model(model, keras_file)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(model=True,
                            optimizer=True,
                            pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)