"src/targets/gpu/kernels/resource.rc" did not exist on "216f96623920a0de313cc3cc8b3abe63aa50ded4"
resnet_cifar_main.py 8.77 KB
Newer Older
Shining Sun's avatar
Shining Sun committed
1
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.benchmark.models import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common


LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (0.1, 91), (0.01, 136), (0.001, 182)
]


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule and LR decay.

  Scales the learning rate at the epoch boundaries given in LR_SCHEDULE by the
  corresponding multiplier.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  del current_batch, batches_per_epoch  # not used
  initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate
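
# Worked example (a sketch, assuming common.BASE_LEARNING_RATE is 0.1): with
# batch_size=128 the initial rate is 0.1 * 128 / 128 = 0.1. Epochs 0-90 then
# train at 0.1, epochs 91-135 at 0.1 * 0.1 = 0.01, epochs 136-181 at 0.001,
# and epochs >= 182 at 0.0001.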


def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  if strategy:
    # flags_obj.enable_get_next_as_optional controls whether to enable the
    # get_next_as_optional behavior in DistributedIterator. If true, the
    # last partial batch can be supported.
    strategy.extended.experimental_enable_get_next_as_optional = (
        flags_obj.enable_get_next_as_optional
    )

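  # get_strategy_scope is expected to return strategy.scope() when a strategy
  # is in use and a no-op context manager otherwise (an assumption about the
  # helper, consistent with its use below for both distributed and
  # single-device runs).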
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = common.get_synth_input_fn(
        height=cifar_preprocessing.HEIGHT,
        width=cifar_preprocessing.WIDTH,
        num_channels=cifar_preprocessing.NUM_CHANNELS,
        num_classes=cifar_preprocessing.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj),
        drop_remainder=True)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = cifar_preprocessing.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype,
      # Set drop_remainder to avoid the partial-batch logic in the
      # normalization layer, which triggers tf.where and leads to an extra
      # memory copy of input sizes between host and GPU.
      drop_remainder=(not flags_obj.enable_get_next_as_optional))

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=cifar_preprocessing.parse_record)

  with strategy_scope:
    optimizer = common.get_optimizer()
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
    # a valid arg for this model. Also remove as a valid flag.
    if flags_obj.force_v2_in_keras_compile is not None:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly,
          experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
    else:
      model.compile(
          loss='sparse_categorical_crossentropy',
          optimizer=optimizer,
          metrics=(['sparse_categorical_accuracy']
                   if flags_obj.report_accuracy_metrics else None),
          run_eagerly=flags_obj.run_eagerly)

  steps_per_epoch = (
      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

  callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)

  # If training for multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1

  num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    if flags_obj.set_learning_phase_to_train:
      # TODO(haoyuzhang): Understand slowdown of setting learning phase when
      # not using distribution strategy.
      tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=callbacks,
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)
  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_cifar_flags():
  common.define_keras_flags(dynamic_loss_scale=False)

  flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                          model_dir='/tmp/cifar10_model',
                          epochs_between_evals=10,
                          batch_size=128)


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    return run(flags.FLAGS)


if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  define_cifar_flags()
  absl_app.run(main)
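
# Example invocation (a sketch; the paths are hypothetical and the flags are
# assumed to come from the common Keras flag definitions used above):
#   python resnet_cifar_main.py \
#     --data_dir=/tmp/cifar10_data/cifar-10-batches-bin \
#     --model_dir=/tmp/cifar10_model \
#     --num_gpus=1 --batch_size=128 --train_epochs=182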