resnet_imagenet_main.py 14 KB
Newer Older
1
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Hongkun Yu's avatar
Hongkun Yu committed
21
22
import os

Hongkun Yu's avatar
Hongkun Yu committed
23
# Import libraries
Hongkun Yu's avatar
Hongkun Yu committed
24
from absl import app
25
from absl import flags
26
from absl import logging
27
import tensorflow as tf
28

29

30
import tensorflow_model_optimization as tfmot
31
from official.modeling import performance
32
33
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
Toby Boyd's avatar
Toby Boyd committed
34
from official.utils.misc import keras_utils
35
from official.utils.misc import model_helpers
Allen Wang's avatar
Allen Wang committed
36
from official.vision.image_classification import test_utils
37
38
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
39
from official.vision.image_classification.resnet import resnet_model
40
41


42
43
44
def cluster_last_three_conv2d_layers(model):
  """Wrap the last three (non-depthwise) Conv2D layers of `model` for clustering.

  The first two of those layers are clustered with 256 centroids, the final
  one with 32 centroids; every other layer is passed through untouched.

  Args:
    model: A `tf.keras.Model` to selectively cluster.

  Returns:
    A clone of `model` in which the selected layers are wrapped by
    `tfmot.clustering.keras.cluster_weights`.
  """
  conv_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
      and not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
  ]
  targets = conv_layers[-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_init = tfmot.clustering.keras.CentroidInitialization.LINEAR
  # Heavier clustering (fewer centroids) only on the very last layer.
  many_clusters = {
      'number_of_clusters': 256,
      'cluster_centroids_init': centroid_init,
  }
  few_clusters = {
      'number_of_clusters': 32,
      'cluster_centroids_init': centroid_init,
  }

  def cluster_fn(layer):
    """clone_model hook: wrap targeted layers, pass all others through."""
    if layer not in targets:
      return layer
    params = many_clusters if layer in targets[:2] else few_clusters
    wrapped = cluster_weights(layer, **params)
    print("Clustered {} with {} clusters".format(
        layer.name, params['number_of_clusters']))
    return wrapped

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)
74
75


Shining Sun's avatar
Shining Sun committed
76
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Builds the requested model (trivial, ResNet50 v1.5, or a MobileNet
  variant), optionally applies pruning/clustering wrappers, trains with
  `model.fit`, optionally evaluates, and optionally exports a SavedModel.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported, or if
      an unrecognized `--optimizer` or `--model` value is given.
    NotImplementedError: If some features are not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance.
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  # Default to channels_first when a GPU is present (typically faster with
  # cuDNN), channels_last otherwise.
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # Configures cluster spec for distribution strategy.
  _ = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                           flags_obj.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether enabling
    # get_next_as_optional behavior in DistributedIterator. If true, last
    # partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        dtype=dtype,
        drop_remainder=True)
  else:
    input_fn = imagenet_preprocessing.input_fn

  # When `enable_xla` is True, we always drop the remainder of the batches
  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
  drop_remainder = flags_obj.enable_xla

  # Current resnet_model.resnet50 input format is always channel-last.
  # We use the keras_application mobilenet model whose input format depends
  # on the keras backend image data format.
  # use_keras_image_data_format indicates whether the image preprocessor
  # output format should match the keras backend image data format or just
  # be channel-last.
  # Bug fix: the previous expression
  # `flags_obj.model == 'mobilenet' or 'mobilenet_pretrained'` was always
  # truthy (a non-empty string literal is True), so the flag was set for
  # every model, including resnet50_v1.5.
  use_keras_image_data_format = flags_obj.model in ('mobilenet',
                                                    'mobilenet_pretrained')
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
          use_keras_image_data_format=use_keras_image_data_format),
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=flags_obj.tf_data_experimental_slack,
      training_dataset_cache=flags_obj.training_dataset_cache,
  )

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
            use_keras_image_data_format=use_keras_image_data_format),
        dtype=dtype,
        drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
      optimizer = tf.keras.optimizers.SGD(
          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
              initial_learning_rate,
              decay_steps=steps_per_epoch * flags_obj.num_epochs_per_decay,
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    elif flags_obj.optimizer == 'mobilenet_fine_tune':
      optimizer = tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9)
    else:
      # Fail fast instead of hitting an UnboundLocalError at compile time.
      raise ValueError('Unknown optimizer: {}'.format(flags_obj.optimizer))

    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
      # which will ensure tf.compat.v2.keras.mixed_precision and
      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
    if flags_obj.use_trivial_model:
      model = test_utils.trivial_model(imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'resnet50_v1.5':
      model = resnet_model.resnet50(
          num_classes=imagenet_preprocessing.NUM_CLASSES)
    elif flags_obj.model == 'mobilenet':
      # TODO(kimjaehong): Remove layers attribute when minimum TF version
      # support 2.0 layers by default.
      model = tf.keras.applications.mobilenet.MobileNet(
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    elif flags_obj.model == 'mobilenet_pretrained':
      model = tf.keras.applications.mobilenet.MobileNet(
          dropout=1e-7,
          weights='imagenet',
          classes=1000,
          layers=tf.keras.layers)
    else:
      # Fail fast instead of hitting an UnboundLocalError below.
      raise ValueError('Unknown model: {}'.format(flags_obj.model))

    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)

    if flags_obj.pruning_method == 'polynomial_decay':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Pruning is currently only supported on dtype=tf.float32.')
      pruning_params = {
          'pruning_schedule':
              tfmot.sparsity.keras.PolynomialDecay(
                  initial_sparsity=flags_obj.pruning_initial_sparsity,
                  final_sparsity=flags_obj.pruning_final_sparsity,
                  begin_step=flags_obj.pruning_begin_step,
                  end_step=flags_obj.pruning_end_step,
                  frequency=flags_obj.pruning_frequency),
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      if dtype != tf.float32 or flags_obj.fp16_implementation == 'graph_rewrite':
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      model = cluster_last_three_conv2d_layers(model)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                          if flags_obj.report_accuracy_metrics else None),
                  run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)

  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  # Remove the training-only pruning/clustering wrapper layers so any export
  # below contains the plain architecture.
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)
  if flags_obj.clustering_method:
    model = tfmot.clustering.keras.strip_clustering(model)

  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
    else:
      # Keras model.save assumes a float32 input signature.
      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
      model.save(export_path, include_optimizer=False)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats
Shining Sun's avatar
bug fix  
Shining Sun committed
336

Shining Sun's avatar
Shining Sun committed
337

Toby Boyd's avatar
Toby Boyd committed
338
def define_imagenet_keras_flags():
  """Register the command-line flags used by the ImageNet Keras runner."""
  common.define_keras_flags(
      model=True, optimizer=True, pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)
Toby Boyd's avatar
Toby Boyd committed
346
347


348
def main(_):
  """absl entry point: clean the model dir if requested, train, log stats."""
  model_helpers.apply_clean(flags.FLAGS)
  stats = run(flags.FLAGS)
  logging.info('Run stats:\n%s', stats)
352
353
354


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  # Flags must be registered before app.run parses argv.
  define_imagenet_keras_flags()
  app.run(main)