# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def _cluster_last_three_conv2d_layers(model):
  """Helper method to cluster last three conv2d layers."""
  import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_initialization = tfmot.clustering.keras.CentroidInitialization

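  # tf.keras.models.clone_model (used below) calls cluster_fn once per layer:
  # layers outside last_three_conv2d_layers pass through unchanged, while the
  # selected Conv2D layers are wrapped with weight clustering.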
  def cluster_fn(layer):
    if layer not in last_three_conv2d_layers:
      return layer

    if layer in last_three_conv2d_layers[:2]:
      clustered = cluster_weights(
          layer,
          number_of_clusters=256,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 256 clusters'.format(layer.name))
    else:
      clustered = cluster_weights(
          layer,
          number_of_clusters=32,
          cluster_centroids_init=centroid_initialization.LINEAR)
      print('Clustered {} with 32 clusters'.format(layer.name))
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
76
    NotImplementedError: If some features are not currently supported.
77
78
79

  Returns:
    Dictionary of training and eval stats.
80
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)
  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
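  # A fixed loss scale (default 128 for fp16) is applied below so that small
  # gradient values do not underflow in reduced precision.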
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128),
      use_experimental_api=True)

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribute_utils.configure_cluster(flags_obj.worker_hosts,
                                         flags_obj.task_index)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last.
  # The Keras Applications MobileNet model, by contrast, uses whichever input
  # format the Keras backend image data format specifies.
  # The use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or simply be channels-last.
  use_keras_image_data_format = flags_obj.model in ('mobilenet',
                                                    'mobilenet_pretrained')

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

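  # common.LR_SCHEDULE is a list of (multiplier, epoch) pairs: the first entry
  # supplies the warmup epochs, and the remaining entries give the epoch
  # boundaries at which the learning rate multiplier changes.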
  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer in ('mobilenet_default', 'mobilenet_fine_tune'):
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model in ('mobilenet', 'mobilenet_pretrained'):
      # TODO(kimjaehong): Remove layers attribute when the minimum TF version
      # supports 2.0 layers by default.
      if flags_obj.model == 'mobilenet_pretrained':
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
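      # Note: pruned training also needs a pruning-step callback; it is
      # assumed that common.get_callbacks(pruning_method=...) below adds
      # tfmot.sparsity.keras.UpdatePruningStep.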
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top
      if (dtype != tf.float32 or
          flags_obj.fp16_implementation == 'graph_rewrite'):
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      model = _cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # The train_steps flag is only honored for single-epoch runs; with
  # multiple epochs it is ignored.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()
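    # The device scope is entered manually here; the matching __exit__ call
    # is made after training and evaluation finish below.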

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

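  # Strip the optimization wrappers before export: strip_pruning and
  # strip_clustering return plain Keras models that retain the optimized
  # weights. tfmot is already bound above whenever pruning or clustering was
  # configured, so the references below are safe.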
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(
      model=True,
      optimizer=True,
      pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)