# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

Hongkun Yu's avatar
Hongkun Yu committed
17
18
import os

Hongkun Yu's avatar
Hongkun Yu committed
19
# Import libraries
Hongkun Yu's avatar
Hongkun Yu committed
20
from absl import app
21
from absl import flags
22
from absl import logging
23
import tensorflow as tf
24
from official.common import distribute_utils
Fan Yang's avatar
Fan Yang committed
25
26
27
28
from official.legacy.image_classification import test_utils
from official.legacy.image_classification.resnet import common
from official.legacy.image_classification.resnet import imagenet_preprocessing
from official.legacy.image_classification.resnet import resnet_model
29
from official.modeling import performance
30
from official.utils.flags import core as flags_core
Toby Boyd's avatar
Toby Boyd committed
31
from official.utils.misc import keras_utils
32
from official.utils.misc import model_helpers
33
34


35
36
37
38
def _cluster_last_three_conv2d_layers(model):
  """Applies weight clustering to the last three Conv2D layers of `model`.

  The first two of those three layers are clustered with 256 centroids and
  the final one with 32 centroids, all using linear centroid initialization.
  All other layers are returned unmodified.

  Args:
    model: A `tf.keras` model. Layers other than the last three `Conv2D`
      layers are left untouched.

  Returns:
    A clone of `model` (via `tf.keras.models.clone_model`) in which the
    selected Conv2D layers are wrapped for weight clustering.
  """
  import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top

  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_initialization = tfmot.clustering.keras.CentroidInitialization

  def cluster_fn(layer):
    if layer not in last_three_conv2d_layers:
      return layer

    # The first two of the selected layers get more clusters (256) than the
    # final one (32).
    if layer in last_three_conv2d_layers[:2]:
      number_of_clusters = 256
    else:
      number_of_clusters = 32
    clustered = cluster_weights(
        layer,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=centroid_initialization.LINEAR)
    # Use absl logging for consistency with the rest of this file.
    logging.info('Clustered %s with %d clusters', layer.name,
                 number_of_clusters)
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)
62
63


Shining Sun's avatar
Shining Sun committed
64
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  # Resolve the compute dtype from flags and install the matching Keras
  # mixed-precision policy before any model/optimizer is built.
  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj))

  # Default to channels_first when a GPU is visible, channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_applications mobilenet model, whose input format depends
  # on the keras backend image data format.
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the keras backend image
  # data format or just channel-last format.
  use_keras_image_data_format = \
    (flags_obj.model == 'mobilenet' or flags_obj.model == 'mobilenet_pretrained')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  # Piecewise-constant learning-rate schedule with warmup, parameterized by
  # common.LR_SCHEDULE ((multiplier, epoch) pairs).
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  # Build optimizer and model under the distribution strategy scope so that
  # their variables are created on the right devices.
  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default' or flags_obj.optimizer == 'mobilenet_fine_tune':
      # Scale the per-sample learning rate by the global batch size.
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      # Fine-tuning overrides the scaled LR with a small fixed value.
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    # Wrap the optimizer for fp16 loss scaling when training in float16.
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=flags_core.get_tf_dtype(flags_obj) == tf.float16,
        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128),)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet' or flags_obj.model == 'mobilenet_pretrained':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      if flags_obj.model == 'mobilenet_pretrained':
        # ImageNet-pretrained weights require the 1000-class head.
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      # NOTE: `tfmot` is imported at function scope here; the same local name
      # is reused by the strip_pruning/strip_clustering calls further below.
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32:
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      model = _cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly,
        jit_compile=flags_obj.enable_xla)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    # Manually enter the device context; it is exited after export below.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # Remove pruning/clustering wrappers before export so the saved model
  # contains plain layers. `tfmot` is the function-local import made above;
  # these branches are only reachable when the corresponding import ran.
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input designature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
Shining Sun's avatar
bug fix  
Shining Sun committed
322

Shining Sun's avatar
Shining Sun committed
323

Toby Boyd's avatar
Toby Boyd committed
324
def define_imagenet_keras_flags():
  """Registers all command-line flags used by this ImageNet Keras trainer."""
  # Core Keras flags, plus the model/optimizer/pretrained-checkpoint options
  # this script needs.
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  # Optional model-optimization flag groups.
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)
Toby Boyd's avatar
Toby Boyd committed
333
334


335
def main(_):
  """Cleans the model dir if requested, runs training/eval, logs the stats."""
  model_helpers.apply_clean(flags.FLAGS)
  run_stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', run_stats)
339
340
341


# Script entry point: enable INFO logging, register flags, then hand control
# to absl, which parses flags and invokes main().
if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)