# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def _cluster_last_three_conv2d_layers(model):
  """Helper method to cluster the last three Conv2D layers of a model."""
  import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_initialization = tfmot.clustering.keras.CentroidInitialization

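  # tf.keras.models.clone_model calls cluster_fn on every layer of the copied
  # model: layers returned unchanged are kept as-is, while layers wrapped with
  # cluster_weights are clustered. The first two of the last three Conv2D
  # layers get 256 clusters; the final one gets 32.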
  def cluster_fn(layer):
    if layer not in last_three_conv2d_layers:
      return layer

    if layer in (last_three_conv2d_layers[0], last_three_conv2d_layers[1]):
      clustered = cluster_weights(
          layer,
          number_of_clusters=256,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 256 clusters'.format(layer.name))
    else:
      clustered = cluster_weights(
          layer,
          number_of_clusters=32,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 32 clusters'.format(layer.name))
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

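  # With --dtype=fp16 the policy below enables mixed precision; the loss
  # scale (defaulting to 128 here) compensates for the narrow dynamic range
  # of float16 gradients to avoid underflow.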
  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      dtype,
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128),
      use_experimental_api=True)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional)

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last,
  # while the keras_applications MobileNet model's input format depends on
  # the Keras backend image data format. The use_keras_image_data_format flag
  # indicates whether the image preprocessor output format should match the
  # Keras backend image data format or always be channels-last.
  use_keras_image_data_format = flags_obj.model in ('mobilenet',
                                                    'mobilenet_pretrained')

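  # Both input_fn variants yield a batched tf.data.Dataset of (image, label)
  # pairs; the threadpool, slack, and caching arguments below are purely
  # input-pipeline performance knobs.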
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

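  # common.LR_SCHEDULE is a list of (multiplier, epoch) pairs: the first
  # entry sets the warmup length and the remaining entries set the
  # piecewise-constant decay boundaries; the base learning rate is also
  # rescaled with the global batch size.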
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

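  # Model variables and optimizer slots must be created under the strategy
  # scope so that they are placed (and, for multi-GPU, mirrored) correctly.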
  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer in ('mobilenet_default', 'mobilenet_fine_tune'):
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) is 'float32', which
      # ensures tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = (
          tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
              optimizer))

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model in ('mobilenet', 'mobilenet_pretrained'):
      # TODO(kimjaehong): Remove layers attribute when the minimum TF version
      # supports 2.0 layers by default.
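      # Loading weights='imagenet' requires the stock 1000-class head; for
      # training from scratch the head is sized for the target dataset and
      # initialized randomly.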
      if flags_obj.model == 'mobilenet_pretrained':
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
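      # PolynomialDecay ramps the sparsity from initial_sparsity to
      # final_sparsity between begin_step and end_step, re-pruning every
      # `frequency` steps; prune_low_magnitude wraps each prunable layer so
      # its lowest-magnitude weights are zeroed out during training.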
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

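    # 'selective_clustering' clusters only the last three Conv2D layers (see
    # _cluster_last_three_conv2d_layers above) rather than the whole model.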
    if flags_obj.clustering_method == 'selective_clustering':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if (dtype != tf.float32 or
          flags_obj.fp16_implementation == 'graph_rewrite'):
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32 '
            'without the fp16 graph_rewrite implementation.')
      model = _cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If running multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

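  # strip_pruning/strip_clustering remove the training-time wrappers and
  # bookkeeping variables, leaving only the final (sparse or clustered)
  # weights for export.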
  if flags_obj.pruning_method:
    import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(
      model=True,
      optimizer=True,
      pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)