"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "6b9d5fba69d2dfbd82f2157c732898c1278d912d"
Unverified Commit 842e5a3e authored by Shining Sun's avatar Shining Sun Committed by GitHub
Browse files

Merge pull request #5928 from tensorflow/cifar_keras_refactor

Cifar keras refactor
parents df122b10 03c35ec6
...@@ -29,17 +29,17 @@ from official.utils.logs import logger ...@@ -29,17 +29,17 @@ from official.utils.logs import logger
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
_HEIGHT = 32 HEIGHT = 32
_WIDTH = 32 WIDTH = 32
_NUM_CHANNELS = 3 NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS _DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
# The record is the image plus a one-byte label # The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1 _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
_NUM_CLASSES = 10 NUM_CLASSES = 10
_NUM_DATA_FILES = 5 _NUM_DATA_FILES = 5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits. # TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
_NUM_IMAGES = { NUM_IMAGES = {
'train': 50000, 'train': 50000,
'validation': 10000, 'validation': 10000,
} }
...@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype): ...@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype):
# The remaining bytes after the label represent the image, which we reshape # The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width]. # from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(record_vector[1:_RECORD_BYTES], depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
[_NUM_CHANNELS, _HEIGHT, _WIDTH]) [NUM_CHANNELS, HEIGHT, WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as # Convert from [depth, height, width] to [height, width, depth], and cast as
# float32. # float32.
...@@ -96,10 +96,10 @@ def preprocess_image(image, is_training): ...@@ -96,10 +96,10 @@ def preprocess_image(image, is_training):
if is_training: if is_training:
# Resize the image to add four extra pixels on each side. # Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad( image = tf.image.resize_image_with_crop_or_pad(
image, _HEIGHT + 8, _WIDTH + 8) image, HEIGHT + 8, WIDTH + 8)
# Randomly crop a [_HEIGHT, _WIDTH] section of the image. # Randomly crop a [HEIGHT, WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS]) image = tf.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
# Randomly flip the image horizontally. # Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image) image = tf.image.random_flip_left_right(image)
...@@ -134,7 +134,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -134,7 +134,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dataset=dataset, dataset=dataset,
is_training=is_training, is_training=is_training,
batch_size=batch_size, batch_size=batch_size,
shuffle_buffer=_NUM_IMAGES['train'], shuffle_buffer=NUM_IMAGES['train'],
parse_record_fn=parse_record_fn, parse_record_fn=parse_record_fn,
num_epochs=num_epochs, num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
...@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def get_synth_input_fn(dtype): def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn( return resnet_run_loop.get_synth_input_fn(
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype) HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES, dtype=dtype)
############################################################################### ###############################################################################
...@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype): ...@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype):
class Cifar10Model(resnet_model.Model): class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data.""" """Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
resnet_version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
...@@ -196,11 +196,11 @@ class Cifar10Model(resnet_model.Model): ...@@ -196,11 +196,11 @@ class Cifar10Model(resnet_model.Model):
def cifar10_model_fn(features, labels, mode, params): def cifar10_model_fn(features, labels, mode, params):
"""Model function for CIFAR-10.""" """Model function for CIFAR-10."""
features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]) features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under. # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn = resnet_run_loop.learning_rate_with_decay( learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=128, batch_size=params['batch_size'], batch_denom=128,
num_images=_NUM_IMAGES['train'], boundary_epochs=[91, 136, 182], num_images=NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
decay_rates=[1, 0.1, 0.01, 0.001]) decay_rates=[1, 0.1, 0.01, 0.001])
# Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
...@@ -261,7 +261,7 @@ def run_cifar(flags_obj): ...@@ -261,7 +261,7 @@ def run_cifar(flags_obj):
input_fn) input_fn)
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags_obj, cifar10_model_fn, input_function, DATASET_NAME, flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS]) shape=[HEIGHT, WIDTH, NUM_CHANNELS])
def main(_): def main(_):
......
...@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing ...@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
_DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3 NUM_CHANNELS = 3
_NUM_CLASSES = 1001 NUM_CLASSES = 1001
_NUM_IMAGES = { NUM_IMAGES = {
'train': 1281167, 'train': 1281167,
'validation': 50000, 'validation': 50000,
} }
...@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype): ...@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype):
image = imagenet_preprocessing.preprocess_image( image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer, image_buffer=image_buffer,
bbox=bbox, bbox=bbox,
output_height=_DEFAULT_IMAGE_SIZE, output_height=DEFAULT_IMAGE_SIZE,
output_width=_DEFAULT_IMAGE_SIZE, output_width=DEFAULT_IMAGE_SIZE,
num_channels=_NUM_CHANNELS, num_channels=NUM_CHANNELS,
is_training=is_training) is_training=is_training)
image = tf.cast(image, dtype) image = tf.cast(image, dtype)
...@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def get_synth_input_fn(dtype): def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn( return resnet_run_loop.get_synth_input_fn(
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES,
dtype=dtype) dtype=dtype)
...@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype): ...@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype):
class ImagenetModel(resnet_model.Model): class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data.""" """Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
resnet_version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
...@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn = resnet_run_loop.learning_rate_with_decay( learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=256, batch_size=params['batch_size'], batch_denom=256,
num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90], num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr) decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
return resnet_run_loop.resnet_model_fn( return resnet_run_loop.resnet_model_fn(
...@@ -343,7 +343,7 @@ def run_imagenet(flags_obj): ...@@ -343,7 +343,7 @@ def run_imagenet(flags_obj):
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags_obj, imagenet_model_fn, input_function, DATASET_NAME, flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]) shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
def main(_): def main(_):
......
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import print_function
import os
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
import official.resnet.keras.keras_cifar_main as keras_cifar_main
import official.resnet.keras.keras_common as keras_common
DATA_DIR = '/data/cifar10_data/'
class KerasCifar10BenchmarkTests(object):
"""Benchmarks and accuracy tests for KerasCifar10."""
local_flags = None
def __init__(self, output_dir=None):
self.oss_report_object = None
self.output_dir = output_dir
def keras_resnet56_1_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
flags.FLAGS.num_gpus = 1
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 182
flags.FLAGS.model_dir = self._get_model_dir('keras_resnet56_1_gpu')
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def keras_resnet56_4_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
flags.FLAGS.num_gpus = 4
flags.FLAGS.data_dir = self._get_model_dir('keras_resnet56_4_gpu')
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 182
flags.FLAGS.model_dir = ''
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def keras_resnet56_no_dist_strat_1_gpu(self):
"""Test keras based model with Keras fit but not distribution strategies."""
self._setup()
flags.FLAGS.turn_off_distribution_strategy = True
flags.FLAGS.num_gpus = 1
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 182
flags.FLAGS.model_dir = self._get_model_dir(
'keras_resnet56_no_dist_strat_1_gpu')
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run(flags.FLAGS)
self._fill_report_object(stats)
def _fill_report_object(self, stats):
if self.oss_report_object:
self.oss_report_object.top_1 = stats['accuracy_top_1']
self.oss_report_object.add_other_quality(stats['training_accuracy_top_1'],
'top_1_train_accuracy')
else:
raise ValueError('oss_report_object has not been set.')
def _get_model_dir(self, folder_name):
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Setups up and resets flags before each test."""
tf.logging.set_verbosity(tf.logging.DEBUG)
if KerasCifar10BenchmarkTests.local_flags is None:
keras_common.define_keras_flags()
cifar_main.define_cifar_flags()
# Loads flags to get defaults to then override.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
KerasCifar10BenchmarkTests.local_flags = saved_flag_values
return
flagsaver.restore_flag_values(KerasCifar10BenchmarkTests.local_flags)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(0.1, 91), (0.01, 136), (0.001, 182)
]
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch sized.
Returns:
Adjusted learning rate.
"""
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
learning_rate = initial_learning_rate * mult
else:
break
return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
This method converts the label to one hot to fit the loss function.
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: Data type to use for input images.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
return image, label
def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
Raises:
ValueError: If fp16 is passed as it is not currently supported.
Returns:
Dictionary of training and eval stats.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT,
width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
else:
input_fn = cifar_main.input_fn
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
eval_input_dataset = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'],
distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1
num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
num_eval_steps = None
validation_data = None
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=1)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output)
return stats
def main(_):
with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
cifar_main.define_cifar_flags()
keras_common.define_keras_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions and classes used by both keras cifar and imagenet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1'
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
def __init__(self, batch_size):
"""Callback for logging performance (# image/second).
Args:
batch_size: Total batch size.
"""
self._batch_size = batch_size
super(TimeHistory, self).__init__()
self.log_steps = 100
def on_train_begin(self, logs=None):
self.record_batch = True
def on_batch_begin(self, batch, logs=None):
if self.record_batch:
self.start_time = time.time()
self.record_batch = False
def on_batch_end(self, batch, logs=None):
if batch % self.log_steps == 0:
elapsed_time = time.time() - self.start_time
examples_per_second = (self._batch_size * self.log_steps) / elapsed_time
self.record_batch = True
# TODO(anjalisridhar): add timestamp as well.
if batch != 0:
tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"'images_per_second': %f}" %
(batch, elapsed_time, examples_per_second))
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only support Keras optimizers, not TF optimizers.
Args:
schedule: a function that takes an epoch index and a batch index as input
(both integer, indexed from 0) and returns a new learning rate as
output (float).
"""
def __init__(self, schedule, batch_size, num_images):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.batches_per_epoch = num_images / batch_size
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.batches_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
def get_optimizer():
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks."""
time_callback = TimeHistory(FLAGS.batch_size)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)
lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn,
batch_size=FLAGS.batch_size,
num_images=num_images)
return time_callback, tensorboard_callback, lr_callback
def build_stats(history, eval_output):
"""Normalizes and returns dictionary of stats.
Args:
history: Results of the training step. Supports both categorical_accuracy
and sparse_categorical_accuracy.
eval_output: Output of the eval step. Assumes first value is eval_loss and
second value is accuracy_top_1.
Returns:
Dictionary of normalized results.
"""
stats = {}
if eval_output:
stats['accuracy_top_1'] = eval_output[1].item()
stats['eval_loss'] = eval_output[0].item()
if history and history.history:
train_hist = history.history
# Gets final loss from training.
stats['loss'] = train_hist['loss'][-1].item()
# Gets top_1 training accuracy.
if 'categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['categorical_accuracy'][-1].item()
elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
return stats
def define_keras_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_integer(
name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
"""Returns an input function that returns a dataset with random data.
This input_fn returns a data set that iterates over a set of random data and
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
copy is still included. This used to find the upper throughput bound when
tuning the full input pipeline.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
# pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs = tf.truncated_normal(
[batch_size] + [height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random_uniform(
[batch_size] + [1],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return data
return input_fn
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_common
tf.logging.set_verbosity(tf.logging.ERROR)
class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCommonTests, cls).setUpClass()
def test_build_stats(self):
history = self._build_history(1.145, cat_accuracy=.99988)
eval_output = self._build_eval_output(.56432111, 5.990)
stats = keras_common.build_stats(history, eval_output)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
self.assertEqual(.56432111, stats['accuracy_top_1'])
self.assertEqual(5.990, stats['eval_loss'])
def test_build_stats_sparse(self):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
self.assertEqual(.928, stats['accuracy_top_1'])
self.assertEqual(1.9844, stats['eval_loss'])
def _build_history(self, loss, cat_accuracy=None,
cat_accuracy_sparse=None):
history_p = Mock()
history = {}
history_p.history = history
history['loss'] = [np.float64(loss)]
if cat_accuracy:
history['categorical_accuracy'] = [np.float64(cat_accuracy)]
if cat_accuracy_sparse:
history['sparse_categorical_accuracy'] = [np.float64(cat_accuracy_sparse)]
return history_p
def _build_eval_output(self, top_1, eval_loss):
eval_output = [np.float64(eval_loss), np.float64(top_1)]
return eval_output
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch sized.
Returns:
Adjusted learning rate.
"""
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
# Learning rate increases linearly per step.
return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
for mult, start_epoch in LR_SCHEDULE:
if epoch >= start_epoch:
learning_rate = initial_lr * mult
else:
break
return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label."""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32)
return image, label
def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
Raises:
ValueError: If fp16 is passed as it is not currently supported.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
else:
input_fn = imagenet_main.input_fn
train_input_dataset = input_fn(is_training=True,
data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
eval_input_dataset = input_fn(is_training=False,
data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'],
distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1
num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
num_eval_steps = None
validation_data = None
model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=1)
if not flags_obj.skip_eval:
model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
def main(_):
with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.
# Reference:
- [Deep Residual Learning for Image Recognition](
https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4
def identity_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
training=None):
"""The identity block is the block that has no conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.add([x, input_tensor])
x = tf.keras.layers.Activation('relu')(x)
return x
def conv_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
strides=(2, 2),
training=None):
"""A block that has a conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the first conv layer in the block.
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
Note that from stage 3,
the first conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
shortcut = tf.keras.layers.Conv2D(filters2, (1, 1), strides=strides,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = tf.keras.layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1',
momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)(
shortcut, training=training)
x = tf.keras.layers.add([x, shortcut])
x = tf.keras.layers.Activation('relu')(x)
return x
def resnet56(classes=100, training=None):
"""Instantiates the ResNet56 architecture.
Arguments:
classes: optional number of classes to classify images into
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
A Keras model instance.
"""
# Determine proper input shape
if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32)
bn_axis = 1
else: # channel_last
input_shape = (32, 32, 3)
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, name='bn_conv1',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = conv_building_block(x, 3, [16, 16], stage=2, block='a', strides=(1, 1),
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='b',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='c',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='d',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='e',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='f',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='g',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='h',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='i',
training=training)
x = conv_building_block(x, 3, [32, 32], stage=3, block='a',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='b',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='c',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='d',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='e',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='f',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='g',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='h',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='i',
training=training)
x = conv_building_block(x, 3, [64, 64], stage=4, block='a',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='b',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='c',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='d',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='e',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='f',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='g',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='h',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='i',
training=training)
x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = tf.keras.layers.Dense(classes, activation='softmax',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x)
inputs = img_input
# Create model.
model = tf.keras.models.Model(inputs, x, name='resnet56')
return model
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model for Keras.
Adapted from tf.keras.applications.resnet50.ResNet50().
This is ResNet model version 1.5.
Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils
L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block):
"""The identity block is the block that has no conv layer at shortcut.
# Arguments
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
# Returns
Output tensor for the block.
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1),
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1),
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
return x
def conv_block(input_tensor,
kernel_size,
filters,
stage,
block,
strides=(2, 2)):
"""A block that has a conv layer at shortcut.
# Arguments
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
# Returns
Output tensor for the block.
Note that from stage 3,
the second conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1),
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x
def resnet50(num_classes):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
Args:
num_classes: `int` number of classes for image classification.
Returns:
A Keras model instance.
"""
# Determine proper input shape
if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224)
bn_axis = 1
else:
input_shape = (224, 224, 3)
bn_axis = 3
img_input = layers.Input(shape=input_shape)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(x)
x = layers.Activation('relu')(x)
x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(
num_classes, activation='softmax',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc1000')(x)
# Create model.
return models.Model(img_input, x, name='resnet50')
...@@ -628,7 +628,10 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -628,7 +628,10 @@ def define_resnet_flags(resnet_size_choices=None):
'the expense of image resize/cropping being done as part of model ' 'the expense of image resize/cropping being done as part of model '
'inference. Note, this flag only applies to ImageNet and cannot ' 'inference. Note, this flag only applies to ImageNet and cannot '
'be used for CIFAR.')) 'be used for CIFAR.'))
flags.DEFINE_boolean(
name='turn_off_distribution_strategy', default=False,
help=flags_core.help_wrap('Set to True to not use distribution '
'strategies.'))
choice_kwargs = dict( choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50', name='resnet_size', short_name='rs', default='50',
help=flags_core.help_wrap('The size of the ResNet model to use.')) help=flags_core.help_wrap('The size of the ResNet model to use.'))
......
...@@ -21,7 +21,9 @@ from __future__ import print_function ...@@ -21,7 +21,9 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
def get_distribution_strategy(num_gpus, all_reduce_alg=None): def get_distribution_strategy(num_gpus,
all_reduce_alg=None,
turn_off_distribution_strategy=False):
"""Return a DistributionStrategy for running the model. """Return a DistributionStrategy for running the model.
Args: Args:
...@@ -30,15 +32,31 @@ def get_distribution_strategy(num_gpus, all_reduce_alg=None): ...@@ -30,15 +32,31 @@ def get_distribution_strategy(num_gpus, all_reduce_alg=None):
See tf.contrib.distribute.AllReduceCrossDeviceOps for available See tf.contrib.distribute.AllReduceCrossDeviceOps for available
algorithms. If None, DistributionStrategy will choose based on device algorithms. If None, DistributionStrategy will choose based on device
topology. topology.
turn_off_distribution_strategy: when set to True, do not use any
distribution strategy. Note that when it is True, and num_gpus is
larger than 1, it will raise a ValueError.
Returns: Returns:
tf.contrib.distribute.DistibutionStrategy object. tf.contrib.distribute.DistibutionStrategy object.
Raises:
ValueError: if turn_off_distribution_strategy is True and num_gpus is
larger than 1
""" """
if num_gpus == 0: if num_gpus == 0:
if turn_off_distribution_strategy:
return None
else:
return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0") return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
elif num_gpus == 1: elif num_gpus == 1:
return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0") if turn_off_distribution_strategy:
return None
else: else:
return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
elif turn_off_distribution_strategy:
raise ValueError("When {} GPUs are specified, "
"turn_off_distribution_strategy flag cannot be set to"
"True.".format(num_gpus))
else: # num_gpus > 1 and not turn_off_distribution_strategy
if all_reduce_alg: if all_reduce_alg:
return tf.contrib.distribute.MirroredStrategy( return tf.contrib.distribute.MirroredStrategy(
num_gpus=num_gpus, num_gpus=num_gpus,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment