# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers


# Step-decay learning-rate schedule: at each `epoch to start` boundary the
# base learning rate is scaled by `multiplier`. The first entry also defines
# the warmup window used by learning_rate_schedule() below (linear ramp to
# the base rate over the first 5 epochs).
LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch,
                           current_batch,
                           batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
  provided scaling factor.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of steps in an epoch.
    batch_size: integer, total batch size.

  Returns:
    Adjusted learning rate.
  """
  # Linear scaling rule: scale the base learning rate by (batch size / 256).
  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
  # Fractional epoch so the warmup ramp advances per batch, not per epoch.
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  # Defensive default: previously `learning_rate` was only bound inside the
  # loop, which worked solely because LR_SCHEDULE[0]'s start epoch doubles as
  # the warmup boundary. Initialize explicitly so editing the schedule cannot
  # raise UnboundLocalError.
  learning_rate = initial_lr * warmup_lr_multiplier
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      # LR_SCHEDULE is ordered by start epoch; later entries cannot apply.
      break
  return learning_rate


def parse_record_keras(raw_record, is_training, dtype):
  """Parse an ImageNet record and adjust the label for the Keras model.

  Args:
    raw_record: a serialized ImageNet example, as accepted by
      imagenet_main.parse_record.
    is_training: bool, whether the record is being parsed for training.
    dtype: data type used for the image tensor.

  Returns:
    Tuple (image, label) where the label has shape [1] and dtype float32,
    shifted down by one so values lie in [0, 1000).
  """
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  # The raw labels are 1-indexed; subtract one so that labels are in
  # [0, 1000), and cast to float32 for the Keras model.
  shifted_label = tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1
  return image, tf.cast(shifted_label, dtype=tf.float32)


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Configures the TF runtime (eager/graph, mixed precision, data format),
  builds real or synthetic input pipelines, constructs a ResNet50 model under
  the selected distribution strategy, trains with model.fit, and optionally
  evaluates.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
  if keras_common.is_v2_0():
    keras_common.set_config_v2()
  else:
    config = keras_common.get_config_proto_v1()
    if flags_obj.enable_eager:
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      # TF1 graph mode: register a session built from the v1 config proto so
      # Keras uses it for everything that follows.
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_common.set_gpu_thread_mode_and_count(flags_obj)

  dtype = flags_core.get_tf_dtype(flags_obj)
  # NOTE(review): the comparison with the string name presumably relies on
  # tf.DType equality coercing 'float16' — confirm against flags_core.
  if dtype == 'float16':
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  data_format = flags_obj.data_format
  if data_format is None:
    # Default to channels_first on CUDA builds, channels_last elsewhere.
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    distribution_utils.set_up_synthetic_data()
    input_fn = keras_common.get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=dtype)
  else:
    distribution_utils.undo_set_up_synthetic_data()
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras,
      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
      dtype=dtype)

  eval_input_dataset = None
  if not flags_obj.skip_eval:
    eval_input_dataset = input_fn(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=flags_obj.batch_size,
        num_epochs=flags_obj.train_epochs,
        parse_record_fn=parse_record_keras,
        dtype=dtype)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      num_workers=distribution_utils.configure_cluster())

  # Model construction/compilation below must happen inside this scope so
  # variables are created under the chosen distribution strategy.
  strategy_scope = keras_common.get_strategy_scope(strategy)

  with strategy_scope:
    optimizer = keras_common.get_optimizer()
    if dtype == 'float16':
      # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
      # can be enabled with a single line of code.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
    model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
                                  dtype=dtype)

    # Labels are scalar class indices (see parse_record_keras), hence the
    # sparse categorical loss and metric.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['sparse_categorical_accuracy'])

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs

  # --train_steps caps the run to a single (possibly shortened) epoch.
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
    # Only build the training graph. This reduces memory usage introduced by
    # control flow ops in layers that have different implementations for
    # training and inference (e.g., batch norm).
    tf.keras.backend.set_learning_phase(1)
    num_eval_steps = None
    validation_data = None

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=validation_data,
                      validation_freq=flags_obj.epochs_between_evals,
                      verbose=2)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)
  stats = keras_common.build_stats(history, eval_output, time_callback)
  return stats


def main(_):
  """Clean the model directory if requested, then train under a benchmark
  logging context and return the resulting stats dictionary."""
  flags_obj = flags.FLAGS
  model_helpers.apply_clean(flags_obj)
  with logger.benchmark_context(flags_obj):
    return run(flags_obj)


if __name__ == '__main__':
  # Surface INFO-level TF logs so training progress is visible on stdout.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  # Register the ImageNet/ResNet flag set and the Keras-specific flags before
  # absl parses the command line.
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)