# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import imagenet_main
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common
from official.resnet.keras import resnet50
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils

# import os
# os.environ['TF2_BEHAVIOR'] = 'enabled'

LR_SCHEDULE = [    # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]


def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then increases linearly per step. After
  5 epochs it reaches the base learning rate (scaled to account for batch
  size). After 30, 60 and 80 epochs the learning rate is divided by 10.
  After 90 epochs training stops and the LR is set to 0. This ensures that
  we train for exactly 90 epochs for reproducibility.

  Args:
    current_epoch: integer, current epoch indexed from 0.
    current_batch: integer, current batch in the current epoch, indexed from 0.
    batches_per_epoch: integer, number of batches in one epoch of training.
    batch_size: integer, global batch size used for training.

  Returns:
    Adjusted learning rate.
  """
  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Learning rate increases linearly per step.
    return (initial_learning_rate * warmup_lr_multiplier *
            epoch / warmup_end_epoch)
  for mult, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_learning_rate * mult
    else:
      break
  return learning_rate
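
# For illustration only (a rough sketch, not used at runtime): with
# batch_size=256 the scaled base rate equals keras_common.BASE_LEARNING_RATE,
# so the schedule above yields approximately:
#   epoch 0.0             -> 0.0   * base  (start of the linear warmup)
#   epoch 2.5             -> 0.5   * base  (warmup scales as epoch / 5)
#   epochs 5 through 29   -> 1.0   * base
#   epochs 30 through 59  -> 0.1   * base
#   epochs 60 through 79  -> 0.01  * base
#   epoch 80 onward       -> 0.001 * base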


def parse_record_keras(raw_record, is_training, dtype):
  """Adjusts the shape of the label for the Keras model."""
  image, label = imagenet_main.parse_record(raw_record, is_training, dtype)

  # Subtract one so that labels are in [0, 1000), and cast to float32 for
  # Keras model.
  label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
                  dtype=tf.float32)
  return image, label


def get_synth_input_fn(height, width, num_channels, num_classes,
                       dtype=tf.float32):
  """Returns an input function that returns a dataset with random data.

  This input_fn returns a dataset that iterates over a set of random data and
  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
  copy is still included. This is used to find the upper throughput bound when
  tuning the full input pipeline.

  Args:
    height: Integer height that will be used to create a fake image tensor.
    width: Integer width that will be used to create a fake image tensor.
    num_channels: Integer depth that will be used to create a fake image
      tensor.
    num_classes: Number of classes that should be represented in the fake
      labels tensor.
    dtype: Data type for features/images.

  Returns:
    An input_fn that can be used in place of a real one to return a dataset
    that can be used for iteration.
  """
  # pylint: disable=unused-argument
  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
    """Returns a dataset filled with random data."""
    # Synthetic input should be within [0, 255].
    inputs = tf.truncated_normal(
        [batch_size] + [height, width, num_channels],
        dtype=dtype,
        mean=127,
        stddev=60,
        name='synthetic_inputs')

    labels = tf.random_uniform(
        [batch_size] + [1],
        minval=0,
        maxval=num_classes - 1,
        dtype=tf.int32,
        name='synthetic_labels')

    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return data

  return input_fn


def run_imagenet_with_keras(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed, as it is not currently supported.
  """
  if flags_obj.enable_eager:
    tf.enable_eager_execution()

  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == tf.float16:
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value (fp32).')

  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size,
      flags_core.get_num_gpus(flags_obj))  # pylint: disable=protected-access

  if flags_obj.use_synthetic_data:
    input_fn = get_synth_input_fn(
        height=imagenet_main.DEFAULT_IMAGE_SIZE,
        width=imagenet_main.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_main.NUM_CHANNELS,
        num_classes=imagenet_main.NUM_CLASSES,
        dtype=flags_core.get_tf_dtype(flags_obj))
  else:
    input_fn = imagenet_main.input_fn

  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
                                 batch_size=per_device_batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
                                batch_size=per_device_batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)

  optimizer = keras_common.get_optimizer()
  strategy = distribution_utils.get_distribution_strategy(
      flags_obj.num_gpus, flags_obj.use_one_device_strategy)

  model = resnet50.ResNet50(num_classes=imagenet_main.NUM_CLASSES)

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=optimizer,
                metrics=['sparse_categorical_accuracy'],
                distribute=strategy)

  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
      learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

  steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)

  train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
  train_epochs = flags_obj.train_epochs
  if flags_obj.train_steps:
    train_steps = min(flags_obj.train_steps, train_steps)
    train_epochs = 1

  history = model.fit(train_input_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      callbacks=[
                          time_callback,
                          lr_callback,
                          tensorboard_callback
                      ],
                      validation_steps=num_eval_steps,
                      validation_data=eval_input_dataset,
                      verbose=1)

  eval_output = None
  if not flags_obj.skip_eval:
    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=1)

  stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
  return stats
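
# Rough sizing sketch (assuming the standard ImageNet split sizes stored in
# imagenet_main.NUM_IMAGES, i.e. 1,281,167 training and 50,000 validation
# images): a global batch size of 1024 gives 1,281,167 // 1024 = 1251 training
# steps per epoch and 50,000 // 1024 = 48 evaluation steps, which is exactly
# how train_steps and num_eval_steps are computed above.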


def main(_):
  with logger.benchmark_context(flags.FLAGS):
    run_imagenet_with_keras(flags.FLAGS)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  imagenet_main.define_imagenet_flags()
  keras_common.define_keras_flags()
  absl_app.run(main)
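
# Example invocations (a sketch only; the script name and paths below are
# placeholders, and the flag names are assumed from the flags_obj attributes
# used above plus the flags registered by imagenet_main.define_imagenet_flags()
# and keras_common.define_keras_flags()):
#
#   # Train and evaluate on real ImageNet data:
#   python keras_imagenet_main.py --data_dir=/path/to/imagenet \
#       --model_dir=/tmp/resnet50_keras --batch_size=128 --train_epochs=90
#
#   # Exercise the model on random synthetic data, bypassing preprocessing:
#   python keras_imagenet_main.py --use_synthetic_data --train_steps=100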