# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import test_utils
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def cluster_last_three_conv2d_layers(model):
  """Applies weight clustering to the last three Conv2D layers of a model.

  The first two of the three layers are clustered with 256 centroids and the
  last with 32; all other layers are returned unchanged.
  """
  import tensorflow_model_optimization as tfmot

  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

  def cluster_fn(layer):
    if layer not in last_three_conv2d_layers:
      return layer

    if layer in (last_three_conv2d_layers[0], last_three_conv2d_layers[1]):
      clustered = cluster_weights(
          layer, number_of_clusters=256,
          cluster_centroids_init=CentroidInitialization.LINEAR)
      print("Clustered {} with 256 clusters".format(layer.name))
    else:
      clustered = cluster_weights(
          layer, number_of_clusters=32,
          cluster_centroids_init=CentroidInitialization.LINEAR)
      print("Clustered {} with 32 clusters".format(layer.name))
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)
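
# A minimal inspection sketch (not part of the training flow), assuming
# tensorflow_model_optimization is imported as tfmot and numpy as np: after
# training, each clustered kernel should contain at most `number_of_clusters`
# unique values.
#   stripped = tfmot.clustering.keras.strip_clustering(clustered_model)
#   conv_layers = [l for l in stripped.layers
#                  if isinstance(l, tf.keras.layers.Conv2D)]
#   for layer in conv_layers[-3:]:
#     print(layer.name, np.unique(layer.kernel.numpy()).size)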


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # The current resnet_model.resnet50 input format is always channels-last.
  # The keras_applications MobileNet model's input format depends on the
  # Keras backend image data format. The use_keras_image_data_format flag
  # indicates whether the image preprocessor output format should match the
  # Keras backend image data format or simply be channels-last.
  use_keras_image_data_format = (
      flags_obj.model in ('mobilenet', 'mobilenet_pretrained'))
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
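
  # With common.LR_SCHEDULE, the schedule above yields a linear warmup
  # followed by piecewise-constant step drops; PiecewiseConstantDecayWithWarmup
  # is also expected (an assumption about its implementation in common.py) to
  # scale the base learning rate with the global batch size.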

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer in ('mobilenet_default', 'mobilenet_fine_tune'):
      initial_learning_rate = (
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size)
      if flags_obj.optimizer == 'mobilenet_fine_tune':
        initial_learning_rate = 1e-5
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
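
    # Illustrative: with staircase=True, ExponentialDecay multiplies the
    # learning rate by lr_decay_factor once every
    # steps_per_epoch * num_epochs_per_decay training steps (e.g., a factor
    # of 0.94 applied every 2.5 epochs' worth of steps).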

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
      # up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model in ('mobilenet', 'mobilenet_pretrained'):
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # supports 2.0 layers by default.
      if flags_obj.model == 'mobilenet_pretrained':
        classes_labels = 1000
        initial_weights = 'imagenet'
      else:
        classes_labels = imagenet_preprocessing.NUM_CLASSES
        initial_weights = None
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=initial_weights,
          classes=classes_labels,
          layers=tf.keras.layers)

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      import tensorflow_model_optimization as tfmot
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')
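
    # For intuition: with, e.g., initial_sparsity=0.0, final_sparsity=0.8,
    # begin_step=0 and end_step=1000, PolynomialDecay re-evaluates the pruning
    # masks every `frequency` steps, ramping the fraction of zeroed weights
    # from 0% to 80% by step 1000 (a cubic ramp by default).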

    if flags_obj.clustering_method == 'selective_clustering':
      import tensorflow_model_optimization as tfmot
      if (dtype != tf.float32 or
          flags_obj.fp16_implementation == 'graph_rewrite'):
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32 '
            'without the fp16 graph rewrite.')
      model = cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=(['sparse_categorical_accuracy']
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If training for more than one epoch, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (
      imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)
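
  # Note: strip_pruning/strip_clustering above remove the tfmot training
  # wrappers, so the exported SavedModel is a plain Keras graph. The size win
  # from sparse or clustered weights only shows up after generic compression;
  # a minimal sketch (weights_file is a hypothetical path):
  #   import gzip, shutil
  #   with open(weights_file, 'rb') as f_in:
  #     with gzip.open(weights_file + '.gz', 'wb') as f_out:
  #       shutil.copyfileobj(f_in, f_out)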

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(
      model=True,
      optimizer=True,
      pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)


def main(_):
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)