Commit 9dafea91 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'qianyj_tf' into 'main'

update tf code

See merge request dcutoolkit/deeplearing/dlexamples_new!35
parents 92a2ca36 a4146470
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
DEFAULT_IMAGE_SIZE = 224
NUM_CHANNELS = 3
NUM_CLASSES = 1001
NUM_IMAGES = {
'train': 1281167,
'validation': 50000,
}
_NUM_TRAIN_FILES = 1024
_SHUFFLE_BUFFER = 10000
DATASET_NAME = 'ImageNet'
###############################################################################
# Data processing
###############################################################################
def get_filenames(is_training, data_dir):
"""Return filenames for dataset."""
if is_training:
return [
os.path.join(data_dir, 'train-%05d-of-01024' % i)
for i in range(_NUM_TRAIN_FILES)]
else:
return [
os.path.join(data_dir, 'validation-%05d-of-00128' % i)
for i in range(128)]
def _parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains
the following fields (values are included as examples):
image/height: 462
image/width: 581
image/colorspace: 'RGB'
image/channels: 3
image/class/label: 615
image/class/synset: 'n03623198'
image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6
image/object/bbox/label: 615
image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string>
Args:
example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
"""
# Dense features in Example proto.
feature_map = {
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
'image/class/label': tf.FixedLenFeature([], dtype=tf.int64,
default_value=-1),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
}
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
# Sparse features in Example proto.
feature_map.update(
{k: sparse_float32 for k in ['image/object/bbox/xmin',
'image/object/bbox/ymin',
'image/object/bbox/xmax',
'image/object/bbox/ymax']})
features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features['image/class/label'], dtype=tf.int32)
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
# Note that we impose an ordering of (y, x) just to make life difficult.
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
# Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1])
return features['image/encoded'], label, bbox
def parse_record(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: data type to use for images/features.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image_buffer, label, bbox = _parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer,
bbox=bbox,
output_height=DEFAULT_IMAGE_SIZE,
output_width=DEFAULT_IMAGE_SIZE,
num_channels=NUM_CHANNELS,
is_training=is_training)
image = tf.cast(image, dtype)
return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
dtype=tf.float32, datasets_num_private_threads=None,
num_parallel_batches=1, parse_record_fn=parse_record):
"""Input function which provides batches for train or eval.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
dtype: Data type to use for images/features
datasets_num_private_threads: Number of private threads for tf.data.
num_parallel_batches: Number of parallel batches for tf.data.
parse_record_fn: Function to use for parsing the records.
Returns:
A dataset that can be used for iteration.
"""
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
if is_training:
# Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
# Convert to individual records.
# cycle_length = 10 means 10 files will be read and deserialized in parallel.
# This number is low enough to not cause too much contention on small systems
# but high enough to provide the benefits of parallelization. You may want
# to increase this number if you have a large number of CPU cores.
dataset = dataset.apply(tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=10))
return resnet_run_loop.process_record_dataset(
dataset=dataset,
is_training=is_training,
batch_size=batch_size,
shuffle_buffer=_SHUFFLE_BUFFER,
parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads,
num_parallel_batches=num_parallel_batches
)
def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn(
DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES,
dtype=dtype)
###############################################################################
# Running the model
###############################################################################
class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data.
Args:
resnet_size: The number of convolutional layers needed in the model.
data_format: Either 'channels_first' or 'channels_last', specifying which
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
"""
# For bigger models, we want to use "bottleneck" layers
if resnet_size < 50:
bottleneck = False
else:
bottleneck = True
super(ImagenetModel, self).__init__(
resnet_size=resnet_size,
bottleneck=bottleneck,
num_classes=num_classes,
num_filters=64,
kernel_size=7,
conv_stride=2,
first_pool_size=3,
first_pool_stride=2,
block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2],
resnet_version=resnet_version,
data_format=data_format,
dtype=dtype
)
def _get_block_sizes(resnet_size):
"""Retrieve the size of each block_layer in the ResNet model.
The number of block layers used for the Resnet model varies according
to the size of the model. This helper grabs the layer set we want, throwing
an error if a non-standard size has been selected.
Args:
resnet_size: The number of convolutional layers needed in the model.
Returns:
A list of block sizes to use in building the model.
Raises:
KeyError: if invalid resnet_size is received.
"""
choices = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
200: [3, 24, 36, 3]
}
try:
return choices[resnet_size]
except KeyError:
err = ('Could not find layers for selected Resnet size.\n'
'Size received: {}; sizes allowed: {}.'.format(
resnet_size, choices.keys()))
raise ValueError(err)
def imagenet_model_fn(features, labels, mode, params):
"""Our model_fn for ResNet to be used with our Estimator."""
# Warmup and higher lr may not be valid for fine tuning with small batches
# and smaller numbers of training images.
if params['fine_tune']:
warmup = False
base_lr = .1
else:
warmup = True
base_lr = .128
learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=256,
num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
return resnet_run_loop.resnet_model_fn(
features=features,
labels=labels,
mode=mode,
model_class=ImagenetModel,
resnet_size=params['resnet_size'],
weight_decay=1e-4,
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype'],
fine_tune=params['fine_tune']
)
def define_imagenet_flags():
resnet_run_loop.define_resnet_flags(
resnet_size_choices=['18', '34', '50', '101', '152', '200'])
flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(train_epochs=90)
def run_imagenet(flags_obj):
"""Run ResNet ImageNet training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
input_fn)
resnet_run_loop.resnet_main(
flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_imagenet(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_imagenet_flags()
absl_app.run(main)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities to preprocess images.
Training images are sampled using the provided bounding boxes, and subsequently
cropped to the sampled bounding box. Images are additionally flipped randomly,
then resized to the target output size (without aspect-ratio preservation).
Images used during evaluation are resized (with aspect-ratio preservation) and
centrally cropped.
All images undergo mean color subtraction.
Note that these steps are colloquially referred to as "ResNet preprocessing,"
and they differ from "VGG preprocessing," which does not use bounding boxes
and instead does an aspect-preserving resize followed by random crop during
training. (These both differ from "Inception preprocessing," which introduces
color distortion steps.)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
_CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
# The lower bound for the smallest side of the image for aspect-preserving
# resizing. For example, if an image is 500 x 1000, it will be resized to
# _RESIZE_MIN x (_RESIZE_MIN * 2).
_RESIZE_MIN = 256
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
"""Crops the given image to a random part of the image, and randomly flips.
We use the fused decode_and_crop op, which performs better than the two ops
used separately in series, but note that this requires that the image be
passed in as an un-decoded string Tensor.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
num_channels: Integer depth of the image buffer for decoding.
Returns:
3-D tensor with cropped image.
"""
# A large fraction of image datasets contain a human-annotated bounding box
# delineating the region of the image containing the object of interest. We
# choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an
# allowed range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.image.extract_jpeg_shape(image_buffer),
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0],
max_attempts=100,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Reassemble the bounding box in the format the crop op requires.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
# Use the fused decode and crop op here, which is faster than each in series.
cropped = tf.image.decode_and_crop_jpeg(
image_buffer, crop_window, channels=num_channels)
# Flip to add a little more random distortion in.
cropped = tf.image.random_flip_left_right(cropped)
return cropped
def _central_crop(image, crop_height, crop_width):
"""Performs central crops of the given image list.
Args:
image: a 3-D image tensor
crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop.
Returns:
3-D tensor with cropped image.
"""
shape = tf.shape(image)
height, width = shape[0], shape[1]
amount_to_be_cropped_h = (height - crop_height)
crop_top = amount_to_be_cropped_h // 2
amount_to_be_cropped_w = (width - crop_width)
crop_left = amount_to_be_cropped_w // 2
return tf.slice(
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
def _mean_image_subtraction(image, means, num_channels):
"""Subtracts the given means from each image channel.
For example:
means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means)
Note that the rank of `image` must be known.
Args:
image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel.
num_channels: number of color channels in the image that will be distorted.
Returns:
the centered image.
Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the
number of values in `means`.
"""
if image.get_shape().ndims != 3:
raise ValueError('Input must be of size [height, width, C>0]')
if len(means) != num_channels:
raise ValueError('len(means) must match the number of channels')
# We have a 1-D tensor of means; convert to 3-D.
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
return image - means
def _smallest_size_at_least(height, width, resize_min):
"""Computes new shape with the smallest side equal to `smallest_side`.
Computes new shape with the smallest side equal to `smallest_side` while
preserving the original aspect ratio.
Args:
height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width.
"""
resize_min = tf.cast(resize_min, tf.float32)
# Convert to floats to make subsequent calculations go smoothly.
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
smaller_dim = tf.minimum(height, width)
scale_ratio = resize_min / smaller_dim
# Convert back to ints to make heights and widths that TF ops will accept.
new_height = tf.cast(height * scale_ratio, tf.int32)
new_width = tf.cast(width * scale_ratio, tf.int32)
return new_height, new_width
def _aspect_preserving_resize(image, resize_min):
"""Resize images preserving the original aspect ratio.
Args:
image: A 3-D image `Tensor`.
resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize.
Returns:
resized_image: A 3-D tensor containing the resized image.
"""
shape = tf.shape(image)
height, width = shape[0], shape[1]
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
return _resize_image(image, new_height, new_width)
def _resize_image(image, height, width):
"""Simple wrapper around tf.resize_images.
This is primarily to make sure we use the same `ResizeMethod` and other
details each time.
Args:
image: A 3-D image `Tensor`.
height: The target height for the resized image.
width: The target width for the resized image.
Returns:
resized_image: A 3-D tensor containing the resized image. The first two
dimensions have the shape [height, width].
"""
return tf.image.resize_images(
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
align_corners=False)
def preprocess_image(image_buffer, bbox, output_height, output_width,
num_channels, is_training=False):
"""Preprocesses the given image.
Preprocessing includes decoding, cropping, and resizing for both training
and eval images. Training preprocessing, however, introduces some random
distortion of the image to improve accuracy.
Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing.
num_channels: Integer depth of the image buffer for decoding.
is_training: `True` if we're preprocessing the image for training and
`False` otherwise.
Returns:
A preprocessed image.
"""
if is_training:
# For training, we want to randomize some of the distortions.
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
image = _resize_image(image, output_height, output_width)
else:
# For validation, we want to decode, resize, then just crop the middle.
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
image = _aspect_preserving_resize(image, _RESIZE_MIN)
image = _central_crop(image, output_height, output_width)
image.set_shape([output_height, output_width, num_channels])
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.utils.testing import integration
tf.logging.set_verbosity(tf.logging.ERROR)
_BATCH_SIZE = 32
_LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
imagenet_main.define_imagenet_flags()
def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape):
"""Returns the expected dimensions depending on if a GPU is being used."""
# If a GPU is used for the test, the shape is returned (already in NCHW
# form). When GPU is not used, the shape is converted to NHWC.
if with_gpu:
return shape
return shape[0], shape[2], shape[3], shape[1]
graph = tf.Graph()
with graph.as_default(), self.test_session(
graph=graph, use_gpu=with_gpu, force_gpu=with_gpu):
model = imagenet_main.ImagenetModel(
resnet_size=resnet_size,
data_format='channels_first' if with_gpu else 'channels_last',
resnet_version=resnet_version,
dtype=dtype
)
inputs = tf.random_uniform([1, 224, 224, 3])
output = model(inputs, training=True)
initial_conv = graph.get_tensor_by_name('resnet_model/initial_conv:0')
max_pool = graph.get_tensor_by_name('resnet_model/initial_max_pool:0')
block_layer1 = graph.get_tensor_by_name('resnet_model/block_layer1:0')
block_layer2 = graph.get_tensor_by_name('resnet_model/block_layer2:0')
block_layer3 = graph.get_tensor_by_name('resnet_model/block_layer3:0')
block_layer4 = graph.get_tensor_by_name('resnet_model/block_layer4:0')
reduce_mean = graph.get_tensor_by_name('resnet_model/final_reduce_mean:0')
dense = graph.get_tensor_by_name('resnet_model/final_dense:0')
self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112)))
self.assertAllEqual(max_pool.shape, reshape((1, 64, 56, 56)))
# The number of channels after each block depends on whether we're
# using the building_block or the bottleneck_block.
if resnet_size < 50:
self.assertAllEqual(block_layer1.shape, reshape((1, 64, 56, 56)))
self.assertAllEqual(block_layer2.shape, reshape((1, 128, 28, 28)))
self.assertAllEqual(block_layer3.shape, reshape((1, 256, 14, 14)))
self.assertAllEqual(block_layer4.shape, reshape((1, 512, 7, 7)))
self.assertAllEqual(reduce_mean.shape, reshape((1, 512, 1, 1)))
else:
self.assertAllEqual(block_layer1.shape, reshape((1, 256, 56, 56)))
self.assertAllEqual(block_layer2.shape, reshape((1, 512, 28, 28)))
self.assertAllEqual(block_layer3.shape, reshape((1, 1024, 14, 14)))
self.assertAllEqual(block_layer4.shape, reshape((1, 2048, 7, 7)))
self.assertAllEqual(reduce_mean.shape, reshape((1, 2048, 1, 1)))
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def tensor_shapes_helper(self, resnet_size, resnet_version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float32, with_gpu=with_gpu)
self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float16, with_gpu=with_gpu)
def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, resnet_version=1)
def test_tensor_shapes_resnet_18_v2(self):
self.tensor_shapes_helper(18, resnet_version=2)
def test_tensor_shapes_resnet_34_v1(self):
self.tensor_shapes_helper(34, resnet_version=1)
def test_tensor_shapes_resnet_34_v2(self):
self.tensor_shapes_helper(34, resnet_version=2)
def test_tensor_shapes_resnet_50_v1(self):
self.tensor_shapes_helper(50, resnet_version=1)
def test_tensor_shapes_resnet_50_v2(self):
self.tensor_shapes_helper(50, resnet_version=2)
def test_tensor_shapes_resnet_101_v1(self):
self.tensor_shapes_helper(101, resnet_version=1)
def test_tensor_shapes_resnet_101_v2(self):
self.tensor_shapes_helper(101, resnet_version=2)
def test_tensor_shapes_resnet_152_v1(self):
self.tensor_shapes_helper(152, resnet_version=1)
def test_tensor_shapes_resnet_152_v2(self):
self.tensor_shapes_helper(152, resnet_version=2)
def test_tensor_shapes_resnet_200_v1(self):
self.tensor_shapes_helper(200, resnet_version=1)
def test_tensor_shapes_resnet_200_v2(self):
self.tensor_shapes_helper(200, resnet_version=2)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v1(self):
self.tensor_shapes_helper(18, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v2(self):
self.tensor_shapes_helper(18, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v1(self):
self.tensor_shapes_helper(34, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v2(self):
self.tensor_shapes_helper(34, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v1(self):
self.tensor_shapes_helper(50, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v2(self):
self.tensor_shapes_helper(50, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v1(self):
self.tensor_shapes_helper(101, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v2(self):
self.tensor_shapes_helper(101, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v1(self):
self.tensor_shapes_helper(152, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v2(self):
self.tensor_shapes_helper(152, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v1(self):
self.tensor_shapes_helper(200, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, resnet_version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, resnet_version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn(dtype)
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn(
features, labels, mode, {
'dtype': dtype,
'resnet_size': 50,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'fine_tune': False,
})
predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
(_BATCH_SIZE, _LABEL_CLASSES))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64)
if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)
if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=2,
dtype=tf.float32)
def _test_imagenetmodel_shape(self, resnet_version):
batch_size = 135
num_classes = 246
model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes,
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenetmodel_shape_v1(self):
self._test_imagenetmodel_shape(resnet_version=1)
def test_imagenetmodel_shape_v2(self):
self._test_imagenetmodel_shape(resnet_version=2)
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
)
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
)
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '18']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-resnet_size', '200']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-resnet_size', '200']
)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import json
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
FLAGS = flags.FLAGS
class KerasBenchmark(tf.test.Benchmark):
"""Base benchmark class with methods to simplify testing."""
local_flags = None
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
self.output_dir = output_dir
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
def _get_model_dir(self, folder_name):
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
tf.logging.set_verbosity(tf.logging.DEBUG)
if KerasBenchmark.local_flags is None:
for flag_method in self.flag_methods:
flag_method()
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
# Overrides flag values with defaults for the class of tests.
for k, v in self.default_flags.items():
setattr(FLAGS, k, v)
saved_flag_values = flagsaver.save_flag_values()
KerasBenchmark.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(KerasBenchmark.local_flags)
def _report_benchmark(self,
stats,
wall_time_sec,
top_1_max=None,
top_1_min=None,
log_steps=None,
total_batch_size=None,
warmup=1):
"""Report benchmark results by writing to local protobuf file
Args:
stats: dict returned from keras models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
top_1_max: highest passing level for top_1 accuracy.
top_1_min: lowest passing level for top_1 accuracy.
log_steps: How often the log was created for stats['step_timestamp_log'].
total_batch_size: Global batch-size.
warmup: number of entries in stats['step_timestamp_log'] to ignore.
"""
extras = {}
if 'accuracy_top_1' in stats:
extras['accuracy'] = self._json_description(
stats['accuracy_top_1'],
priority=0,
min_value=top_1_min,
max_value=top_1_max)
extras['top_1_train_accuracy'] = self._json_description(
stats['training_accuracy_top_1'], priority=1)
if (warmup and 'step_timestamp_log' in stats and
len(stats['step_timestamp_log']) > warmup):
# first entry in the time_log is start of step 1. The rest of the
# entries are the end of each step recorded
time_log = stats['step_timestamp_log']
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
num_examples = (
total_batch_size * log_steps * (len(time_log) - warmup - 1))
examples_per_sec = num_examples / elapsed
extras['exp_per_second'] = self._json_description(
examples_per_sec, priority=2)
if 'avg_exp_per_second' in stats:
extras['avg_exp_per_second'] = self._json_description(
stats['avg_exp_per_second'], priority=3)
self.report_benchmark(iters=-1, wall_time=wall_time_sec, extras=extras)
def _json_description(self,
value,
priority=None,
min_value=None,
max_value=None):
"""Get a json-formatted string describing the attributes for a metric"""
attributes = {}
attributes['value'] = value
if priority:
attributes['priority'] = priority
if min_value:
attributes['min_value'] = min_value
if max_value:
attributes['max_value'] = max_value
if min_value or max_value:
succeeded = True
if min_value and value < min_value:
succeeded = False
if max_value and value > max_value:
succeeded = False
attributes['succeeded'] = succeeded
return json.dumps(attributes)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import flags
from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_benchmark
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
DATA_DIR = '/data/cifar10_data/cifar-10-batches-bin'
MIN_TOP_1_ACCURACY = 0.925
MAX_TOP_1_ACCURACY = 0.938
FLAGS = flags.FLAGS
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
"""Accuracy tests for ResNet56 Keras CIFAR-10."""
def __init__(self, output_dir=None):
flag_methods = [
keras_common.define_keras_flags, cifar_main.define_cifar_flags
]
super(Resnet56KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def benchmark_graph_1_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('keras_resnet56_1_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
"""Test keras based model with eager and distribution strategies."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('keras_resnet56_eager_1_gpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
self._run_and_report_benchmark()
def benchmark_2_gpu(self):
"""Test keras based model with eager and distribution strategies."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('keras_resnet56_eager_2_gpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
self._run_and_report_benchmark()
def benchmark_graph_2_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
FLAGS.num_gpus = 2
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('keras_resnet56_2_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self):
"""Test keras based model with Keras fit but not distribution strategies."""
self._setup()
FLAGS.turn_off_distribution_strategy = True
FLAGS.num_gpus = 1
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir('keras_resnet56_no_dist_strat_1_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasAccuracy, self)._report_benchmark(
stats,
wall_time_sec,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY,
total_batch_size=FLAGS.batch_size,
log_steps=100)
class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Short performance tests for ResNet56 via Keras and CIFAR-10."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
keras_common.define_keras_flags, cifar_main.define_cifar_flags
]
super(Resnet56KerasBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_cifar_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_no_dist_strat(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = True
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = True
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_2_gpu(self):
self._setup()
FLAGS.num_gpus = 2
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128 * 2 # 2 GPUs
self._run_and_report_benchmark()
def benchmark_graph_2_gpu(self):
self._setup()
FLAGS.num_gpus = 2
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128 * 2 # 2 GPUs
self._run_and_report_benchmark()
class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
"""Synthetic benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
super(Resnet56KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags)
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
"""Real data benchmarks for ResNet56 and Keras."""
def __init__(self, output_dir=None):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['data_dir'] = DATA_DIR
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
super(Resnet56KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the Cifar-10 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(0.1, 91), (0.01, 136), (0.001, 182)
]
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch sized.
Returns:
Adjusted learning rate.
"""
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
learning_rate = initial_learning_rate * mult
else:
break
return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
This method converts the label to one hot to fit the loss function.
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: Data type to use for input images.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
return image, label
def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
Raises:
ValueError: If fp16 is passed as it is not currently supported.
Returns:
Dictionary of training and eval stats.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT,
width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
else:
input_fn = cifar_main.input_fn
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
eval_input_dataset = input_fn(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
strategy = distribution_utils.get_distribution_strategy(
num_gpus=flags_obj.num_gpus,
turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1
num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=2)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output, time_callback)
return stats
def main(_):
with logger.benchmark_context(flags.FLAGS):
return run(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
cifar_main.define_cifar_flags()
keras_common.define_keras_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions and classes used by both keras cifar and imagenet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1'
class BatchTimestamp(object):
"""A structure to store batch time stamp."""
def __init__(self, batch_index, timestamp):
self.batch_index = batch_index
self.timestamp = timestamp
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
def __init__(self, batch_size, log_steps):
"""Callback for logging performance (# image/second).
Args:
batch_size: Total batch size.
"""
self.batch_size = batch_size
super(TimeHistory, self).__init__()
self.log_steps = log_steps
# Logs start of step 0 then end of each step based on log_steps interval.
self.timestamp_log = []
def on_train_begin(self, logs=None):
self.record_batch = True
def on_train_end(self, logs=None):
self.train_finish_time = time.time()
def on_batch_begin(self, batch, logs=None):
if self.record_batch:
timestamp = time.time()
self.start_time = timestamp
self.record_batch = False
if batch == 0:
self.timestamp_log.append(BatchTimestamp(batch, timestamp))
def on_batch_end(self, batch, logs=None):
if batch % self.log_steps == 0:
timestamp = time.time()
elapsed_time = timestamp - self.start_time
examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
if batch != 0:
self.record_batch = True
self.timestamp_log.append(BatchTimestamp(batch, timestamp))
tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"'images_per_second': %f}" %
(batch, elapsed_time, examples_per_second))
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only support Keras optimizers, not TF optimizers.
Args:
schedule: a function that takes an epoch index and a batch index as input
(both integer, indexed from 0) and returns a new learning rate as
output (float).
"""
def __init__(self, schedule, batch_size, num_images):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.batches_per_epoch = num_images / batch_size
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.batches_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
def get_optimizer():
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks."""
time_callback = TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)
lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn,
batch_size=FLAGS.batch_size,
num_images=num_images)
return time_callback, tensorboard_callback, lr_callback
def build_stats(history, eval_output, time_callback):
"""Normalizes and returns dictionary of stats.
Args:
history: Results of the training step. Supports both categorical_accuracy
and sparse_categorical_accuracy.
eval_output: Output of the eval step. Assumes first value is eval_loss and
second value is accuracy_top_1.
time_callback: Time tracking callback likely used during keras.fit.
Returns:
Dictionary of normalized results.
"""
stats = {}
if eval_output:
stats['accuracy_top_1'] = eval_output[1].item()
stats['eval_loss'] = eval_output[0].item()
if history and history.history:
train_hist = history.history
# Gets final loss from training.
stats['loss'] = train_hist['loss'][-1].item()
# Gets top_1 training accuracy.
if 'categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['categorical_accuracy'][-1].item()
elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
if time_callback:
timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats
def define_keras_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_integer(
name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
'examples per second. Besides, for every log_steps, we store the '
'timestamp of a batch end.')
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
"""Returns an input function that returns a dataset with random data.
This input_fn returns a data set that iterates over a set of random data and
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
copy is still included. This used to find the upper throughput bound when
tuning the full input pipeline.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
# pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs = tf.truncated_normal(
[height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random_uniform(
[1],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.batch(batch_size)
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return data
return input_fn
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet.keras import keras_common
tf.logging.set_verbosity(tf.logging.ERROR)
class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCommonTests, cls).setUpClass()
def test_build_stats(self):
history = self._build_history(1.145, cat_accuracy=.99988)
eval_output = self._build_eval_output(.56432111, 5.990)
th = keras_common.TimeHistory(128, 100)
th.batch_start_timestamps = [1, 2, 3]
th.batch_end_timestamps = [4, 5, 6]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, th)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
self.assertEqual(.56432111, stats['accuracy_top_1'])
self.assertEqual(5.990, stats['eval_loss'])
self.assertItemsEqual([1, 2, 3], stats['batch_start_timestamps'])
self.assertItemsEqual([4, 5, 6], stats['batch_end_timestamps'])
self.assertEqual(12345, stats['train_finish_time'])
def test_build_stats_sparse(self):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
self.assertEqual(.928, stats['accuracy_top_1'])
self.assertEqual(1.9844, stats['eval_loss'])
def test_time_history(self):
th = keras_common.TimeHistory(batch_size=128, log_steps=3)
th.on_train_begin()
th.on_batch_begin(0)
th.on_batch_end(0)
th.on_batch_begin(1)
th.on_batch_end(1)
th.on_batch_begin(2)
th.on_batch_end(2)
th.on_batch_begin(3)
th.on_batch_end(3)
th.on_batch_begin(4)
th.on_batch_end(4)
th.on_batch_begin(5)
th.on_batch_end(5)
th.on_batch_begin(6)
th.on_batch_end(6)
th.on_train_end()
self.assertEqual(3, len(th.batch_start_timestamps))
self.assertEqual(2, len(th.batch_end_timestamps))
self.assertEqual(0, th.batch_start_timestamps[0].batch_index)
self.assertEqual(1, th.batch_start_timestamps[1].batch_index)
self.assertEqual(4, th.batch_start_timestamps[2].batch_index)
self.assertEqual(3, th.batch_end_timestamps[0].batch_index)
self.assertEqual(6, th.batch_end_timestamps[1].batch_index)
def _build_history(self, loss, cat_accuracy=None,
cat_accuracy_sparse=None):
history_p = Mock()
history = {}
history_p.history = history
history['loss'] = [np.float64(loss)]
if cat_accuracy:
history['categorical_accuracy'] = [np.float64(cat_accuracy)]
if cat_accuracy_sparse:
history['sparse_categorical_accuracy'] = [np.float64(cat_accuracy_sparse)]
return history_p
def _build_eval_output(self, top_1, eval_loss):
eval_output = [np.float64(eval_loss), np.float64(top_1)]
return eval_output
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import print_function
import os
import time
from absl import flags
from official.resnet import imagenet_main
from official.resnet.keras import keras_benchmark
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
MIN_TOP_1_ACCURACY = 0.76
MAX_TOP_1_ACCURACY = 0.77
DATA_DIR = '/data/imagenet/'
FLAGS = flags.FLAGS
class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
"""Benchmark accuracy tests for ResNet50 in Keras."""
def __init__(self, output_dir=None):
flag_methods = [
keras_common.define_keras_flags, imagenet_main.define_imagenet_flags
]
super(Resnet50KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def benchmark_graph_8_gpu(self):
"""Test Keras model with Keras fit/dist_strat and 8 GPUs."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128 * 8
FLAGS.train_epochs = 90
FLAGS.model_dir = self._get_model_dir('keras_resnet50_8_gpu')
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
"""Test Keras model with eager, dist_strat and 8 GPUs."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = DATA_DIR
FLAGS.batch_size = 128 * 8
FLAGS.train_epochs = 90
FLAGS.model_dir = self._get_model_dir('keras_resnet50_eager_8_gpu')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
self._run_and_report_benchmark()
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_imagenet_main.run(flags.FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet50KerasAccuracy, self)._report_benchmark(
stats,
wall_time_sec,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY,
total_batch_size=FLAGS.batch_size,
log_steps=100)
def _get_model_dir(self, folder_name):
return os.path.join(self.output_dir, folder_name)
class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
keras_common.define_keras_flags, imagenet_main.define_imagenet_flags
]
super(Resnet50KerasBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = keras_imagenet_main.run(FLAGS)
wall_time_sec = time.time() - start_time_sec
super(Resnet50KerasBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
def benchmark_1_gpu_no_dist_strat(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = True
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = True
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_graph_1_gpu(self):
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_eager = True
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128 * 8 # 8 GPUs
self._run_and_report_benchmark()
def benchmark_graph_8_gpu(self):
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_eager = False
FLAGS.turn_off_distribution_strategy = False
FLAGS.batch_size = 128 * 8 # 8 GPUs
self._run_and_report_benchmark()
def fill_report_object(self, stats):
super(Resnet50KerasBenchmarkBase, self).fill_report_object(
stats,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase):
"""Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags)
class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
"""Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None):
def_flags = {}
def_flags['skip_eval'] = True
def_flags['data_dir'] = DATA_DIR
def_flags['train_steps'] = 110
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch sized.
Returns:
Adjusted learning rate.
"""
initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
# Learning rate increases linearly per step.
return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
for mult, start_epoch in LR_SCHEDULE:
if epoch >= start_epoch:
learning_rate = initial_lr * mult
else:
break
return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label."""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32)
return image, label
def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
Raises:
ValueError: If fp16 is passed as it is not currently supported.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
else:
input_fn = imagenet_main.input_fn
train_input_dataset = input_fn(is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
eval_input_dataset = input_fn(is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
strategy = distribution_utils.get_distribution_strategy(
num_gpus=flags_obj.num_gpus,
turn_off_distribution_strategy=flags_obj.turn_off_distribution_strategy)
strategy_scope = keras_common.get_strategy_scope(strategy)
with strategy_scope:
optimizer = keras_common.get_optimizer()
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1
num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=2)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output, time_callback)
return stats
def main(_):
with logger.benchmark_context(flags.FLAGS):
return run(flags.FLAGS)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.
# Reference:
- [Deep Residual Learning for Image Recognition](
https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4
def identity_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
training=None):
"""The identity block is the block that has no conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.add([x, input_tensor])
x = tf.keras.layers.Activation('relu')(x)
return x
def conv_building_block(input_tensor,
kernel_size,
filters,
stage,
block,
strides=(2, 2),
training=None):
"""A block that has a conv layer at shortcut.
Arguments:
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the first conv layer in the block.
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
Output tensor for the block.
Note that from stage 3,
the first conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2 = filters
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = tf.keras.layers.Conv2D(filters1, kernel_size, strides=strides,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2a',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis,
name=bn_name_base + '2b',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
shortcut = tf.keras.layers.Conv2D(filters2, (1, 1), strides=strides,
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = tf.keras.layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1',
momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)(
shortcut, training=training)
x = tf.keras.layers.add([x, shortcut])
x = tf.keras.layers.Activation('relu')(x)
return x
def resnet56(classes=100, training=None):
"""Instantiates the ResNet56 architecture.
Arguments:
classes: optional number of classes to classify images into
training: Only used if training keras model with Estimator. In other
scenarios it is handled automatically.
Returns:
A Keras model instance.
"""
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channel_last
x = img_input
bn_axis = 3
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = tf.keras.layers.BatchNormalization(axis=bn_axis, name='bn_conv1',
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON)(
x, training=training)
x = tf.keras.layers.Activation('relu')(x)
x = conv_building_block(x, 3, [16, 16], stage=2, block='a', strides=(1, 1),
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='b',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='c',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='d',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='e',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='f',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='g',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='h',
training=training)
x = identity_building_block(x, 3, [16, 16], stage=2, block='i',
training=training)
x = conv_building_block(x, 3, [32, 32], stage=3, block='a',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='b',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='c',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='d',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='e',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='f',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='g',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='h',
training=training)
x = identity_building_block(x, 3, [32, 32], stage=3, block='i',
training=training)
x = conv_building_block(x, 3, [64, 64], stage=4, block='a',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='b',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='c',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='d',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='e',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='f',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='g',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='h',
training=training)
x = identity_building_block(x, 3, [64, 64], stage=4, block='i',
training=training)
x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = tf.keras.layers.Dense(classes, activation='softmax',
kernel_initializer='he_normal',
kernel_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=
tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
name='fc10')(x)
inputs = img_input
# Create model.
model = tf.keras.models.Model(inputs, x, name='resnet56')
return model
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model for Keras.
Adapted from tf.keras.applications.resnet50.ResNet50().
This is ResNet model version 1.5.
Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import utils
L2_WEIGHT_DECAY = 1e-4
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
def identity_block(input_tensor, kernel_size, filters, stage, block):
"""The identity block is the block that has no conv layer at shortcut.
# Arguments
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
# Returns
Output tensor for the block.
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size,
padding='same', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
x = layers.add([x, input_tensor])
x = layers.Activation('relu')(x)
return x
def conv_block(input_tensor,
kernel_size,
filters,
stage,
block,
strides=(2, 2)):
"""A block that has a conv layer at shortcut.
# Arguments
input_tensor: input tensor
kernel_size: default 3, the kernel size of
middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
# Returns
Output tensor for the block.
Note that from stage 3,
the second conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2, filters3 = filters
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = layers.Conv2D(filters1, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters2, kernel_size, strides=strides, padding='same',
use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters3, (1, 1), use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '2c')(x)
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name=bn_name_base + '1')(shortcut)
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x
def resnet50(num_classes):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
Args:
num_classes: `int` number of classes for image classification.
Returns:
A Keras model instance.
"""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channels_last
x = img_input
bn_axis = 3
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid', use_bias=False,
kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis,
momentum=BATCH_NORM_DECAY,
epsilon=BATCH_NORM_EPSILON,
name='bn_conv1')(x)
x = layers.Activation('relu')(x)
x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(
num_classes, activation='softmax',
kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
name='fc1000')(x)
# Create model.
return models.Model(img_input, x, name='resnet50')
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test that the definitions of ResNet layers haven't changed.
These tests will fail if either:
a) The graph of a resnet layer changes and the change is significant enough
that it can no longer load existing checkpoints.
b) The numerical results produced by the layer change.
A warning will be issued if the graph changes, but the checkpoint still loads.
In the event that a layer change is intended, or the TensorFlow implementation
of a layer changes (and thus changes the graph), regenerate using the command:
$ python3 layer_test.py -regen
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model
from official.utils.testing import reference_data
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
BATCH_SIZE = 32
BLOCK_TESTS = [
dict(bottleneck=True, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=True, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=True, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=True, projection=False, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=2, width=8,
channels=4),
]
class BaseTest(reference_data.BaseTest):
"""Tests for core ResNet layers."""
@property
def test_name(self):
return "resnet"
def _batch_norm_ops(self, test=False):
name = "batch_norm"
g = tf.Graph()
with g.as_default():
tf.set_random_seed(self.name_to_seed(name))
input_tensor = tf.get_variable(
"input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((32, 16, 16, 3), maxval=1)
)
layer = resnet_model.batch_norm(
inputs=input_tensor, data_format=DATA_FORMAT, training=True)
self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
correctness_function=self.default_correctness_function
)
def make_projection(self, filters_out, strides, data_format):
"""1D convolution with stride projector.
Args:
filters_out: Number of filters in the projection.
strides: Stride length for convolution.
data_format: channels_first or channels_last
Returns:
A CNN projector function with kernel_size 1.
"""
def projection_shortcut(inputs):
return resnet_model.conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
data_format=data_format)
return projection_shortcut
def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
resnet_version, width, channels):
"""Test whether resnet block construction has changed.
Args:
test: Whether or not to run as a test case.
batch_size: Number of points in the fake image. This is needed due to
batch normalization.
bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input.
resnet_version: Which version of ResNet to test.
width: The width of the fake image.
channels: The number of channels in the fake image.
"""
name = "batch-size-{}_{}{}_version-{}_width-{}_channels-{}".format(
batch_size,
"bottleneck" if bottleneck else "building",
"_projection" if projection else "",
resnet_version,
width,
channels
)
if resnet_version == 1:
block_fn = resnet_model._building_block_v1
if bottleneck:
block_fn = resnet_model._bottleneck_block_v1
else:
block_fn = resnet_model._building_block_v2
if bottleneck:
block_fn = resnet_model._bottleneck_block_v2
g = tf.Graph()
with g.as_default():
tf.set_random_seed(self.name_to_seed(name))
strides = 1
channels_out = channels
projection_shortcut = None
if projection:
strides = 2
channels_out *= strides
projection_shortcut = self.make_projection(
filters_out=channels_out, strides=strides, data_format=DATA_FORMAT)
filters = channels_out
if bottleneck:
filters = channels_out // 4
input_tensor = tf.get_variable(
"input_tensor", dtype=tf.float32,
initializer=tf.random_uniform((batch_size, width, width, channels),
maxval=1)
)
layer = block_fn(inputs=input_tensor, filters=filters, training=True,
projection_shortcut=projection_shortcut, strides=strides,
data_format=DATA_FORMAT)
self._save_or_test_ops(
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
correctness_function=self.default_correctness_function
)
def test_batch_norm(self):
self._batch_norm_ops(test=True)
def test_block_0(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[0])
def test_block_1(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[1])
def test_block_2(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[2])
def test_block_3(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[3])
def test_block_4(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[4])
def test_block_5(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[5])
def test_block_6(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[6])
def test_block_7(self):
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[7])
def regenerate(self):
"""Create reference data files for ResNet layer tests."""
self._batch_norm_ops(test=False)
for block_params in BLOCK_TESTS:
self._resnet_block_ops(test=False, batch_size=BATCH_SIZE, **block_params)
if __name__ == "__main__":
reference_data.main(argv=sys.argv, test_class=BaseTest)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for Residual Networks.
Residual networks ('v1' ResNets) were originally proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
The full preactivation 'v2' ResNet variant was introduced by:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
The key difference of the full preactivation 'v2' variant compared to the
'v1' variant in [1] is the use of batch normalization before every weight layer
rather than after.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES
################################################################################
# Convenience functions for building the ResNet model.
################################################################################
def batch_norm(inputs, training, data_format):
"""Performs a batch normalization using a standard set of parameters."""
# We set fused=True for a significant performance boost. See
# https://www.tensorflow.org/performance/performance_guide#common_fused_ops
return tf.layers.batch_normalization(
inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
scale=True, training=training, fused=True)
def fixed_padding(inputs, kernel_size, data_format):
"""Pads the input along the spatial dimensions independently of input size.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
Should be a positive integer.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
A tensor with the same format as the input with the data either intact
(if kernel_size == 1) or padded (if kernel_size > 1).
"""
pad_total = kernel_size - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
if data_format == 'channels_first':
padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
[pad_beg, pad_end], [pad_beg, pad_end]])
else:
padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
[pad_beg, pad_end], [0, 0]])
return padded_inputs
def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
"""Strided 2-D convolution with explicit padding."""
# The padding is consistent and is based only on `kernel_size`, not on the
# dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
if strides > 1:
inputs = fixed_padding(inputs, kernel_size, data_format)
return tf.layers.conv2d(
inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
padding=('SAME' if strides == 1 else 'VALID'), use_bias=False,
kernel_initializer=tf.variance_scaling_initializer(),
data_format=data_format)
################################################################################
# ResNet block definitions.
################################################################################
def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
data_format):
"""A single block for ResNet v1, without a bottleneck.
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the convolutions.
training: A Boolean for whether the model is in training or inference
mode. Needed for batch normalization.
projection_shortcut: The function to use for projection shortcuts
(typically a 1x1 convolution when downsampling the input).
strides: The block's stride. If greater than 1, this block will ultimately
downsample the input.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
The output tensor of the block; shape should match inputs.
"""
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm(inputs=shortcut, training=training,
data_format=data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs += shortcut
inputs = tf.nn.relu(inputs)
return inputs
def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
data_format):
"""A single block for ResNet v2, without a bottleneck.
Batch normalization then ReLu then convolution as described by:
Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the convolutions.
training: A Boolean for whether the model is in training or inference
mode. Needed for batch normalization.
projection_shortcut: The function to use for projection shortcuts
(typically a 1x1 convolution when downsampling the input).
strides: The block's stride. If greater than 1, this block will ultimately
downsample the input.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
The output tensor of the block; shape should match inputs.
"""
shortcut = inputs
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
# The projection shortcut should come after the first batch norm and ReLU
# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=1,
data_format=data_format)
return inputs + shortcut
def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
strides, data_format):
"""A single block for ResNet v1, with a bottleneck.
Similar to _building_block_v1(), except using the "bottleneck" blocks
described in:
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the convolutions.
training: A Boolean for whether the model is in training or inference
mode. Needed for batch normalization.
projection_shortcut: The function to use for projection shortcuts
(typically a 1x1 convolution when downsampling the input).
strides: The block's stride. If greater than 1, this block will ultimately
downsample the input.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
The output tensor of the block; shape should match inputs.
"""
shortcut = inputs
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
shortcut = batch_norm(inputs=shortcut, training=training,
data_format=data_format)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs += shortcut
inputs = tf.nn.relu(inputs)
return inputs
def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
strides, data_format):
"""A single block for ResNet v2, with a bottleneck.
Similar to _building_block_v2(), except using the "bottleneck" blocks
described in:
Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition
https://arxiv.org/pdf/1512.03385.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Adapted to the ordering conventions of:
Batch normalization then ReLu then convolution as described by:
Identity Mappings in Deep Residual Networks
https://arxiv.org/pdf/1603.05027.pdf
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the convolutions.
training: A Boolean for whether the model is in training or inference
mode. Needed for batch normalization.
projection_shortcut: The function to use for projection shortcuts
(typically a 1x1 convolution when downsampling the input).
strides: The block's stride. If greater than 1, this block will ultimately
downsample the input.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
The output tensor of the block; shape should match inputs.
"""
shortcut = inputs
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
# The projection shortcut should come after the first batch norm and ReLU
# since it performs a 1x1 convolution.
if projection_shortcut is not None:
shortcut = projection_shortcut(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=1, strides=1,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=filters, kernel_size=3, strides=strides,
data_format=data_format)
inputs = batch_norm(inputs, training, data_format)
inputs = tf.nn.relu(inputs)
inputs = conv2d_fixed_padding(
inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
data_format=data_format)
return inputs + shortcut
def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
training, name, data_format):
"""Creates one layer of blocks for the ResNet model.
Args:
inputs: A tensor of size [batch, channels, height_in, width_in] or
[batch, height_in, width_in, channels] depending on data_format.
filters: The number of filters for the first convolution of the layer.
bottleneck: Is the block created a bottleneck block.
block_fn: The block to use within the model, either `building_block` or
`bottleneck_block`.
blocks: The number of blocks contained in the layer.
strides: The stride to use for the first convolution of the layer. If
greater than 1, this layer will ultimately downsample the input.
training: Either True or False, whether we are currently training the
model. Needed for batch norm.
name: A string name for the tensor output of the block layer.
data_format: The input format ('channels_last' or 'channels_first').
Returns:
The output tensor of the block layer.
"""
# Bottleneck blocks end with 4x the number of filters as they start with
filters_out = filters * 4 if bottleneck else filters
def projection_shortcut(inputs):
return conv2d_fixed_padding(
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
data_format=data_format)
# Only the first block per block_layer uses projection_shortcut and strides
inputs = block_fn(inputs, filters, training, projection_shortcut, strides,
data_format)
for _ in range(1, blocks):
inputs = block_fn(inputs, filters, training, None, 1, data_format)
return tf.identity(inputs, name)
class Model(object):
"""Base class for building the Resnet Model."""
def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
kernel_size,
conv_stride, first_pool_size, first_pool_stride,
block_sizes, block_strides,
resnet_version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image.
Args:
resnet_size: A single integer for the size of the ResNet model.
bottleneck: Use regular blocks or bottleneck blocks.
num_classes: The number of classes used as labels.
num_filters: The number of filters to use for the first block layer
of the model. This number is then doubled for each subsequent block
layer.
kernel_size: The kernel size to use for convolution.
conv_stride: stride size for the initial convolutional layer
first_pool_size: Pool size to be used for the first pooling layer.
If none, the first pooling layer is skipped.
first_pool_stride: stride size for the first pooling layer. Not used
if first_pool_size is None.
block_sizes: A list containing n values, where n is the number of sets of
block layers desired. Each value should be the number of blocks in the
i-th set.
block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes.
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
dtype: The TensorFlow dtype to use for calculations. If not specified
tf.float32 is used.
Raises:
ValueError: if invalid version is selected.
"""
self.resnet_size = resnet_size
if not data_format:
data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
self.resnet_version = resnet_version
if resnet_version not in (1, 2):
raise ValueError(
'Resnet version should be 1 or 2. See README for citations.')
self.bottleneck = bottleneck
if bottleneck:
if resnet_version == 1:
self.block_fn = _bottleneck_block_v1
else:
self.block_fn = _bottleneck_block_v2
else:
if resnet_version == 1:
self.block_fn = _building_block_v1
else:
self.block_fn = _building_block_v2
if dtype not in ALLOWED_TYPES:
raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))
self.data_format = data_format
self.num_classes = num_classes
self.num_filters = num_filters
self.kernel_size = kernel_size
self.conv_stride = conv_stride
self.first_pool_size = first_pool_size
self.first_pool_stride = first_pool_stride
self.block_sizes = block_sizes
self.block_strides = block_strides
self.dtype = dtype
self.pre_activation = resnet_version == 2
def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE,
*args, **kwargs):
"""Creates variables in fp32, then casts to fp16 if necessary.
This function is a custom getter. A custom getter is a function with the
same signature as tf.get_variable, except it has an additional getter
parameter. Custom getters can be passed as the `custom_getter` parameter of
tf.variable_scope. Then, tf.get_variable will call the custom getter,
instead of directly getting a variable itself. This can be used to change
the types of variables that are retrieved with tf.get_variable.
The `getter` parameter is the underlying variable getter, that would have
been called if no custom getter was used. Custom getters typically get a
variable with `getter`, then modify it in some way.
This custom getter will create an fp32 variable. If a low precision
(e.g. float16) variable was requested it will then cast the variable to the
requested dtype. The reason we do not directly create variables in low
precision dtypes is that applying small gradients to such variables may
cause the variable not to change.
Args:
getter: The underlying variable getter, that has the same signature as
tf.get_variable and returns a variable.
name: The name of the variable to get.
shape: The shape of the variable to get.
dtype: The dtype of the variable to get. Note that if this is a low
precision dtype, the variable will be created as a tf.float32 variable,
then cast to the appropriate dtype
*args: Additional arguments to pass unmodified to getter.
**kwargs: Additional keyword arguments to pass unmodified to getter.
Returns:
A variable which is cast to fp16 if necessary.
"""
if dtype in CASTABLE_TYPES:
var = getter(name, shape, tf.float32, *args, **kwargs)
return tf.cast(var, dtype=dtype, name=name + '_cast')
else:
return getter(name, shape, dtype, *args, **kwargs)
def _model_variable_scope(self):
"""Returns a variable scope that the model should be created under.
If self.dtype is a castable type, model variable will be created in fp32
then cast to self.dtype before being used.
Returns:
A variable scope for the model.
"""
return tf.variable_scope('resnet_model',
custom_getter=self._custom_dtype_getter)
def __call__(self, inputs, training):
"""Add operations to classify a batch of input images.
Args:
inputs: A Tensor representing a batch of input images.
training: A boolean. Set to True to add operations required only when
training the classifier.
Returns:
A logits Tensor with shape [<batch_size>, self.num_classes].
"""
with self._model_variable_scope():
if self.data_format == 'channels_first':
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv2d_fixed_padding(
inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size,
strides=self.conv_stride, data_format=self.data_format)
inputs = tf.identity(inputs, 'initial_conv')
# We do not include batch normalization or activation functions in V2
# for the initial conv1 because the first ResNet unit will perform these
# for both the shortcut and non-shortcut paths as part of the first
# block's projection. Cf. Appendix of [2].
if self.resnet_version == 1:
inputs = batch_norm(inputs, training, self.data_format)
inputs = tf.nn.relu(inputs)
if self.first_pool_size:
inputs = tf.layers.max_pooling2d(
inputs=inputs, pool_size=self.first_pool_size,
strides=self.first_pool_stride, padding='SAME',
data_format=self.data_format)
inputs = tf.identity(inputs, 'initial_max_pool')
for i, num_blocks in enumerate(self.block_sizes):
num_filters = self.num_filters * (2**i)
inputs = block_layer(
inputs=inputs, filters=num_filters, bottleneck=self.bottleneck,
block_fn=self.block_fn, blocks=num_blocks,
strides=self.block_strides[i], training=training,
name='block_layer{}'.format(i + 1), data_format=self.data_format)
# Only apply the BN and ReLU for model that does pre_activation in each
# building/bottleneck block, eg resnet V2.
if self.pre_activation:
inputs = batch_norm(inputs, training, self.data_format)
inputs = tf.nn.relu(inputs)
# The current top layer has shape
# `batch_size x pool_size x pool_size x final_size`.
# ResNet does an Average Pooling layer over pool_size,
# but that is the same as doing a reduce_mean. We do a reduce_mean
# here because it performs better than AveragePooling2D.
axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
inputs = tf.reduce_mean(inputs, axes, keepdims=True)
inputs = tf.identity(inputs, 'final_reduce_mean')
inputs = tf.squeeze(inputs, axes)
inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains utility and supporting functions for ResNet.
This module contains ResNet code which does not directly build layers. This
includes dataset management, hyperparameter and optimizer code, and argument
parsing. Code for defining the ResNet layers can be found in resnet_model.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
import multiprocessing
import os
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
from tensorflow.contrib.data.python.ops import threadpool
from official.resnet import resnet_model
from official.utils.flags import core as flags_core
from official.utils.export import export
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.resnet import imagenet_preprocessing
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
################################################################################
# Functions for input processing.
################################################################################
def process_record_dataset(dataset,
is_training,
batch_size,
shuffle_buffer,
parse_record_fn,
num_epochs=1,
dtype=tf.float32,
datasets_num_private_threads=None,
num_parallel_batches=1):
"""Given a Dataset with raw records, return an iterator over the records.
Args:
dataset: A Dataset representing raw records
is_training: A boolean denoting whether the input is for training.
batch_size: The number of samples per batch.
shuffle_buffer: The buffer size to use when shuffling records. A larger
value results in better randomness, but smaller values reduce startup
time and use less memory.
parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset.
dtype: Data type to use for images/features.
datasets_num_private_threads: Number of threads for a private
threadpool created for all datasets computation.
num_parallel_batches: Number of parallel batches for tf.data.
Returns:
Dataset of (image, label) pairs ready for iteration.
"""
# Prefetches a batch at a time to smooth out the time taken to load input
# files for shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size)
if is_training:
# Shuffles records before repeating to respect epoch boundaries.
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
# Repeats the dataset for the number of epochs to train.
dataset = dataset.repeat(num_epochs)
# Parses the raw records into images and labels.
dataset = dataset.apply(
tf.contrib.data.map_and_batch(
lambda value: parse_record_fn(value, is_training, dtype),
batch_size=batch_size,
num_parallel_batches=num_parallel_batches,
drop_remainder=False))
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
# Defines a specific size thread pool for tf.data operations.
if datasets_num_private_threads:
tf.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
dataset = threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
datasets_num_private_threads,
display_name='input_pipeline_thread_pool'))
return dataset
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
"""Returns an input function that returns a dataset with random data.
This input_fn returns a data set that iterates over a set of random data and
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
copy is still included. This used to find the upper throughput bound when
tunning the full input pipeline.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
# pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs = tf.truncated_normal(
[batch_size] + [height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random_uniform(
[batch_size],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return data
return input_fn
def image_bytes_serving_input_fn(image_shape, dtype=tf.float32):
"""Serving input fn for raw jpeg images."""
def _preprocess_image(image_bytes):
"""Preprocess a single raw image."""
# Bounding box around the whole image.
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=dtype, shape=[1, 1, 4])
height, width, num_channels = image_shape
image = imagenet_preprocessing.preprocess_image(
image_bytes, bbox, height, width, num_channels, is_training=False)
return image
image_bytes_list = tf.placeholder(
shape=[None], dtype=tf.string, name='input_tensor')
images = tf.map_fn(
_preprocess_image, image_bytes_list, back_prop=False, dtype=dtype)
return tf.estimator.export.TensorServingInputReceiver(
images, {'image_bytes': image_bytes_list})
def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj):
"""Override flags and set env_vars for performance.
These settings exist to test the difference between using stock settings
and manual tuning. It also shows some of the ENV_VARS that can be tweaked to
squeeze a few extra examples per second. These settings are defaulted to the
current platform of interest, which changes over time.
On systems with small numbers of cpu cores, e.g. under 8 logical cores,
setting up a gpu thread pool with `tf_gpu_thread_mode=gpu_private` may perform
poorly.
Args:
flags_obj: Current flags, which will be adjusted possibly overriding
what has been set by the user on the command-line.
"""
cpu_count = multiprocessing.cpu_count()
tf.logging.info('Logical CPU cores: %s', cpu_count)
# Sets up thread pool for each GPU for op scheduling.
per_gpu_thread_count = 1
total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
tf.logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT'])
tf.logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])
# Reduces general thread pool by number of threads used for GPU pool.
main_thread_count = cpu_count - total_gpu_thread_count
flags_obj.inter_op_parallelism_threads = main_thread_count
# Sets thread count for tf.data. Logical cores minus threads assign to the
# private GPU pool along with 2 thread per GPU for event monitoring and
# sending / receiving tensors.
num_monitoring_threads = 2 * flags_obj.num_gpus
flags_obj.datasets_num_private_threads = (cpu_count - total_gpu_thread_count
- num_monitoring_threads)
################################################################################
# Functions for running training/eval/validation loops for the model.
################################################################################
def learning_rate_with_decay(
batch_size, batch_denom, num_images, boundary_epochs, decay_rates,
base_lr=0.1, warmup=False):
"""Get a learning rate that decays step-wise as training progresses.
Args:
batch_size: the number of examples processed in each training batch.
batch_denom: this value will be used to scale the base learning rate.
`0.1 * batch size` is divided by this number, such that when
batch_denom == batch_size, the initial learning rate will be 0.1.
num_images: total number of images that will be used for training.
boundary_epochs: list of ints representing the epochs at which we
decay the learning rate.
decay_rates: list of floats representing the decay rates to be used
for scaling the learning rate. It should have one more element
than `boundary_epochs`, and all elements should have the same type.
base_lr: Initial learning rate scaled based on batch_denom.
warmup: Run a 5 epoch warmup to the initial lr.
Returns:
Returns a function that takes a single argument - the number of batches
trained so far (global_step)- and returns the learning rate to be used
for training the next batch.
"""
initial_learning_rate = base_lr * batch_size / batch_denom
batches_per_epoch = num_images / batch_size
# Reduce the learning rate at certain epochs.
# CIFAR-10: divide by 10 at epoch 100, 150, and 200
# ImageNet: divide by 10 at epoch 30, 60, 80, and 90
boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
vals = [initial_learning_rate * decay for decay in decay_rates]
def learning_rate_fn(global_step):
"""Builds scaled learning rate function with 5 epoch warm up."""
lr = tf.train.piecewise_constant(global_step, boundaries, vals)
if warmup:
warmup_steps = int(batches_per_epoch * 5)
warmup_lr = (
initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(
warmup_steps, tf.float32))
return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
return lr
return learning_rate_fn
def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, resnet_version, loss_scale,
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
fine_tune=False):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
and uses that model to build the necessary EstimatorSpecs for
the `mode` in question. For training, this means building losses,
the optimizer, and the train op that get passed into the EstimatorSpec.
For evaluation and prediction, the EstimatorSpec is returned without
a train op, but with the necessary parameters for the given mode.
Args:
features: tensor representing input images
labels: tensor representing class labels for all input images
mode: current estimator mode; should be one of
`tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`
model_class: a class representing a TensorFlow model that has a __call__
function. We assume here that this is a subclass of ResnetModel.
resnet_size: A single integer for the size of the ResNet model.
weight_decay: weight decay loss rate used to regularize learned variables.
learning_rate_fn: function that returns the current learning rate given
the current global_step
momentum: momentum term used for optimization
data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available.
resnet_version: Integer representing which version of the ResNet network to
use. See README for details. Valid values: [1, 2]
loss_scale: The factor to scale the loss for numerical stability. A detailed
summary is present in the arg parser help text.
loss_filter_fn: function that takes a string variable name and returns
True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded
from the loss.
dtype: the TensorFlow dtype to use for calculations.
fine_tune: If True only train the dense layers(final layers).
Returns:
EstimatorSpec parameterized according to the input params and the
current mode.
"""
# Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6)
# Checks that features/images have same data type being used for calculations.
assert features.dtype == dtype
model = model_class(resnet_size, data_format, resnet_version=resnet_version,
dtype=dtype)
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
# This acts as a no-op if the logits are already in fp32 (provided logits are
# not a SparseTensor). If dtype is is low precision, logits must be cast to
# fp32 for numerical stability.
logits = tf.cast(logits, tf.float32)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
if mode == tf.estimator.ModeKeys.PREDICT:
# Return the predictions and the specification for serving a SavedModel
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'predict': tf.estimator.export.PredictOutput(predictions)
})
# Calculate loss, which includes softmax cross entropy and L2 regularization.
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
# Create a tensor named cross_entropy for logging purposes.
tf.identity(cross_entropy, name='cross_entropy')
tf.summary.scalar('cross_entropy', cross_entropy)
# If no loss_filter_fn is passed, assume we want the default behavior,
# which is that batch_normalization variables are excluded from loss.
def exclude_batch_norm(name):
return 'batch_normalization' not in name
loss_filter_fn = loss_filter_fn or exclude_batch_norm
# Add weight decay to the loss.
l2_loss = weight_decay * tf.add_n(
# loss is computed using fp32 for numerical stability.
[tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
if loss_filter_fn(v.name)])
tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy + l2_loss
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
learning_rate = learning_rate_fn(global_step)
# Create a tensor named learning_rate for logging purposes
tf.identity(learning_rate, name='learning_rate')
tf.summary.scalar('learning_rate', learning_rate)
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate,
momentum=momentum
)
def _dense_grad_filter(gvs):
"""Only apply gradient updates to the final layer.
This function is used for fine tuning.
Args:
gvs: list of tuples with gradients and variable info
Returns:
filtered gradients so that only the dense layer remains
"""
return [(g, v) for g, v in gvs if 'dense' in v.name]
if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are
# so small, they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scale times bigger.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
if fine_tune:
scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)
# Once the gradient computation is complete we can scale the gradients
# back to the correct scale before passing them to the optimizer.
unscaled_grad_vars = [(grad / loss_scale, var)
for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
else:
grad_vars = optimizer.compute_gradients(loss)
if fine_tune:
grad_vars = _dense_grad_filter(grad_vars)
minimize_op = optimizer.apply_gradients(grad_vars, global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(minimize_op, update_ops)
else:
train_op = None
accuracy = tf.metrics.accuracy(labels, predictions['classes'])
accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits,
targets=labels,
k=5,
name='top_5_op'))
metrics = {'accuracy': accuracy,
'accuracy_top_5': accuracy_top_5}
# Create a tensor named train_accuracy for logging purposes
tf.identity(accuracy[1], name='train_accuracy')
tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
tf.summary.scalar('train_accuracy', accuracy[1])
tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=metrics)
def resnet_main(
flags_obj, model_function, input_function, dataset_name, shape=None):
"""Shared main loop for ResNet Models.
Args:
flags_obj: An object containing parsed flags. See define_resnet_flags()
for details.
model_function: the function that instantiates the Model and builds the
ops for train/eval. This will be passed directly into the estimator.
input_function: the function that processes the dataset and returns a
dataset that the estimator can train on. This will be wrapped with
all the relevant flags for running and passed to estimator.
dataset_name: the name of the dataset for training and evaluation. This is
used for logging purpose.
shape: list of ints representing the shape of the images used for training.
This is only used if flags_obj.export_dir is passed.
Returns:
Dict of results of the run.
"""
model_helpers.apply_clean(flags.FLAGS)
# Ensures flag override logic is only executed if explicitly triggered.
if flags_obj.tf_gpu_thread_mode:
override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)
# Creates session config. allow_soft_placement = True, is required for
# multi-GPU and is not harmful for other modes.
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
# Creates a `RunConfig` that checkpoints every 24 hours which essentially
# results in checkpoints determined only by `epochs_between_evals`.
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy,
session_config=session_config,
save_checkpoints_secs=60*60*24)
# Initializes model with all but the dense layer from pretrained ResNet.
if flags_obj.pretrained_model_checkpoint_path is not None:
warm_start_settings = tf.estimator.WarmStartSettings(
flags_obj.pretrained_model_checkpoint_path,
vars_to_warm_start='^(?!.*dense)')
else:
warm_start_settings = None
classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
warm_start_from=warm_start_settings, params={
'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size,
'resnet_version': int(flags_obj.resnet_version),
'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj),
'fine_tune': flags_obj.fine_tune
})
run_params = {
'batch_size': flags_obj.batch_size,
'dtype': flags_core.get_tf_dtype(flags_obj),
'resnet_size': flags_obj.resnet_size,
'resnet_version': flags_obj.resnet_version,
'synthetic_data': flags_obj.use_synthetic_data,
'train_epochs': flags_obj.train_epochs,
}
if flags_obj.use_synthetic_data:
dataset_name = dataset_name + '-synthetic'
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info('resnet', dataset_name, run_params,
test_id=flags_obj.benchmark_test_id)
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks,
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
def input_fn_train(num_epochs):
return input_function(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=num_epochs,
dtype=flags_core.get_tf_dtype(flags_obj),
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
num_parallel_batches=flags_obj.datasets_num_parallel_batches)
def input_fn_eval():
return input_function(
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1,
dtype=flags_core.get_tf_dtype(flags_obj))
if flags_obj.eval_only or not flags_obj.train_epochs:
# If --eval_only is set, perform a single loop with zero train epochs.
schedule, n_loops = [0], 1
else:
# Compute the number of times to loop while training. All but the last
# pass will train for `epochs_between_evals` epochs, while the last will
# train for the number needed to reach `training_epochs`. For instance if
# train_epochs = 25 and epochs_between_evals = 10
# schedule will be set to [10, 10, 5]. That is to say, the loop will:
# Train for 10 epochs and then evaluate.
# Train for another 10 epochs and then evaluate.
# Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1]) # over counting.
for cycle_index, num_train_epochs in enumerate(schedule):
tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))
if num_train_epochs:
classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
hooks=train_hooks, max_steps=flags_obj.max_train_steps)
tf.logging.info('Starting to evaluate.')
# flags_obj.max_train_steps is generally associated with testing and
# profiling. As a result it is frequently called with synthetic data, which
# will iterate forever. Passing steps=flags_obj.max_train_steps allows the
# eval (which is generally unimportant in those circumstances) to terminate.
# Note that eval will run for max_train_steps each loop, regardless of the
# global_step count.
eval_results = classifier.evaluate(input_fn=input_fn_eval,
steps=flags_obj.max_train_steps)
benchmark_logger.log_evaluation_result(eval_results)
if model_helpers.past_stop_threshold(
flags_obj.stop_threshold, eval_results['accuracy']):
break
if flags_obj.export_dir is not None:
# Exports a saved model for the given classifier.
export_dtype = flags_core.get_tf_dtype(flags_obj)
if flags_obj.image_bytes_as_serving_input:
input_receiver_fn = functools.partial(
image_bytes_serving_input_fn, shape, dtype=export_dtype)
else:
input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
strip_default_attrs=True)
return eval_results
def define_resnet_flags(resnet_size_choices=None):
"""Add flags and validators for ResNet."""
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
datasets_num_parallel_batches=True)
flags_core.define_image()
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum(
name='resnet_version', short_name='rv', default='1',
enum_values=['1', '2'],
help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.'))
flags.DEFINE_bool(
name='fine_tune', short_name='ft', default=False,
help=flags_core.help_wrap(
'If True do not train any parameters except for the final layer.'))
flags.DEFINE_string(
name='pretrained_model_checkpoint_path', short_name='pmcp', default=None,
help=flags_core.help_wrap(
'If not None initialize all the network except the final layer with '
'these values'))
flags.DEFINE_boolean(
name='eval_only', default=False,
help=flags_core.help_wrap('Skip training and only perform evaluation on '
'the latest checkpoint.'))
flags.DEFINE_boolean(
name='image_bytes_as_serving_input', default=False,
help=flags_core.help_wrap(
'If True exports savedmodel with serving signature that accepts '
'JPEG image bytes instead of a fixed size [HxWxC] tensor that '
'represents the image. The former is easier to use for serving at '
'the expense of image resize/cropping being done as part of model '
'inference. Note, this flag only applies to ImageNet and cannot '
'be used for CIFAR.'))
flags.DEFINE_boolean(
name='turn_off_distribution_strategy', default=False,
help=flags_core.help_wrap('Set to True to not use distribution '
'strategies.'))
choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
help=flags_core.help_wrap('The size of the ResNet model to use.'))
if resnet_size_choices is None:
flags.DEFINE_string(**choice_kwargs)
else:
flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
# Transformer Translation Model
This is an implementation of the Transformer translation model as described in the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. Based on the code provided by the authors: [Transformer code](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py) from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor).
Transformer is a neural network architecture that solves sequence to sequence problems using attention mechanisms. Unlike traditional neural seq2seq models, Transformer does not involve recurrent connections. The attention mechanism learns dependencies between tokens in two sequences. Since attention weights apply to all tokens in the sequences, the Transformer model is able to easily capture long-distance dependencies.
Transformer's overall structure follows the standard encoder-decoder pattern. The encoder uses self-attention to compute a representation of the input sequence. The decoder generates the output sequence one token at a time, taking the encoder output and previous decoder-outputted tokens as inputs.
The model also applies embeddings on the input and output tokens, and adds a constant positional encoding. The positional encoding adds information about the position of each token.
## Contents
* [Contents](#contents)
* [Walkthrough](#walkthrough)
* [Benchmarks](#benchmarks)
* [Training times](#training-times)
* [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
* [Environment preparation](#environment-preparation)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
* [Compute official BLEU score](#compute-official-bleu-score)
* [TPU](#tpu)
* [Export trained model](#export-trained-model)
* [Example translation](#example-translation)
* [Implementation overview](#implementation-overview)
* [Model Definition](#model-definition)
* [Model Estimator](#model-estimator)
* [Other scripts](#other-scripts)
* [Test dataset](#test-dataset)
* [Term definitions](#term-definitions)
## Walkthrough
Below are the commands for running the Transformer model. See the [Detailed instrutions](#detailed-instructions) for more details on running the model.
```
cd /path/to/models/official/transformer
# Ensure that PYTHONPATH is correctly defined as described in
# https://github.com/tensorflow/models/tree/master/official#requirements
# export PYTHONPATH="$PYTHONPATH:/path/to/models"
# Export variables
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
VOCAB_FILE=$DATA_DIR/vocab.ende.32768
# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR
# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
--bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
# Run during training in a separate process to get continuous updates,
# or after training is complete.
tensorboard --logdir=$MODEL_DIR
# Translate some text using the trained model
python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
--param_set=$PARAM_SET --text="hello world"
# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
--param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
## Benchmarks
### Training times
Currently, both big and base parameter sets run on a single GPU. The measurements below
are reported from running the model on a P100 GPU.
Param Set | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr
### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.
Param Set | Score
--- | --- |
base | 27.7
big | 28.9
## Detailed instructions
0. ### Environment preparation
#### Add models repo to PYTHONPATH
Follow the instructions described in the [Requirements](https://github.com/tensorflow/models/tree/master/official#requirements) section to add the models folder to the python path.
#### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below:
```
PARAM_SET=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAM_SET
VOCAB_FILE=$DATA_DIR/vocab.ende.32768
```
1. ### Download and preprocess datasets
[data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.
1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.
Command to run:
```
python data_download.py --data_dir=$DATA_DIR
```
Arguments:
* `--data_dir`: Path where the preprocessed TFRecord data, and vocab file will be saved.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
2. ### Model training and evaluation
[transformer_main.py](transformer_main.py) creates a Transformer model, and trains it using Tensorflow Estimator.
Command to run:
```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--vocab_file=$VOCAB_FILE --param_set=$PARAM_SET
```
Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
* `--vocab_file`: Path to subtoken vacbulary file. If data_download was used, you may find the file in `data_dir`.
* `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Customizing training schedule
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset
* `--epochs_between_evals`: The number of epochs to train between evaluations.
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
* `--steps_between_evals`: Number of training steps to run between evaluations.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_evals=1000`.
Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more than others. Therefore, it is recommended to use epochs when the model quality is important.
#### Compute BLEU score during model evaluation
Use these flags to compute the BLEU when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
* `--stop_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
When running `transformer_main.py`, use the flags: `--bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de`
#### Tensorboard
Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using Tensorboard.
```
tensorboard --logdir=$MODEL_DIR
```
The values are displayed at [localhost:6006](localhost:6006).
3. ### Translate using the model
[translate.py](translate.py) contains the script to use the trained model to translate input text or file. Each line in the file is translated separately.
Command to run:
```
python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
--param_set=$PARAM_SET --text="hello world"
```
Arguments for initializing the Subtokenizer and trained model:
* `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model
* `--vocab_file`: Path to subtoken vacbulary file. If data_download was used, you may find the file in `data_dir`.
Arguments for specifying what to translate:
* `--text`: Text to translate
* `--file`: Path to file containing text to translate
* `--file_out`: If `--file` is set, then this file will store the input file's translations.
To translate the newstest2014 data, run:
```
python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
--param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
```
Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
4. ### Compute official BLEU score
Use [compute_bleu.py](compute_bleu.py) to compute the BLEU by comparing generated translations to the reference translation.
Command to run:
```
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
Arguments:
* `--translation`: Path to file containing generated translations.
* `--reference`: Path to file containing reference translations.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
5. ### TPU
TPU support for this version of Transformer is experimental. Currently it is present for
demonstration purposes only, but will be optimized in the coming weeks.
## Export trained model
To export the model as a Tensorflow [SavedModel](https://www.tensorflow.org/guide/saved_model) format, use the argument `--export_dir` when running `transformer_main.py`. A folder will be created in the directory with the name as the timestamp (e.g. $EXPORT_DIR/1526427396).
```
EXPORT_DIR=$HOME/transformer/saved_model
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--vocab_file=$VOCAB_FILE --param_set=$PARAM_SET --export_model=$EXPORT_DIR
```
To inspect the SavedModel, use saved_model_cli:
```
SAVED_MODEL_DIR=$EXPORT_DIR/{TIMESTAMP} # replace {TIMESTAMP} with the name of the folder created
saved_model_cli show --dir=$SAVED_MODEL_DIR --all
```
### Example translation
Let's translate **"hello world!"**, **"goodbye world."**, and **"Would you like some pie?"**.
The SignatureDef for "translate" is:
signature_def['translate']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input'] tensor_info:
dtype: DT_INT64
shape: (-1, -1)
name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['outputs'] tensor_info:
dtype: DT_INT32
shape: (-1, -1)
name: model/Transformer/strided_slice_19:0
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: model/Transformer/strided_slice_20:0
Follow the steps below to use the translate signature def:
1. #### Encode the inputs to integer arrays.
This can be done using `utils.tokenizer.Subtokenizer`, and the vocab file in the SavedModel assets (`$SAVED_MODEL_DIR/assets.extra/vocab.txt`).
```
from official.transformer.utils.tokenizer import Subtokenizer
s = Subtokenizer(PATH_TO_VOCAB_FILE)
print(s.encode("hello world!", add_eos=True))
```
The encoded inputs are:
* `"hello world!" = [6170, 3731, 178, 207, 1]`
* `"goodbye world." = [15431, 13966, 36, 178, 3, 1]`
* `"Would you like some pie?" = [9092, 72, 155, 202, 19851, 102, 1]`
2. #### Run `saved_model_cli` to obtain the predicted translations
The encoded inputs should be padded so that they are the same length. The padding token is `0`.
```
ENCODED_INPUTS="[[26228, 145, 178, 1, 0, 0, 0], \
[15431, 13966, 36, 178, 3, 1, 0], \
[9092, 72, 155, 202, 19851, 102, 1]]"
```
Now, use the `run` command with `saved_model_cli` to get the outputs.
```
saved_model_cli run --dir=$SAVED_MODEL_DIR --tag_set=serve --signature_def=translate \
--input_expr="input=$ENCODED_INPUTS"
```
The outputs will look similar to:
```
Result for output key outputs:
[[18744 145 297 1 0 0 0 0 0 0 0 0
0 0]
[ 5450 4642 21 11 297 3 1 0 0 0 0 0
0 0]
[25940 22 66 103 21713 31 102 1 0 0 0 0
0 0]]
Result for output key scores:
[-1.5493642 -1.4032784 -3.252089 ]
```
3. #### Decode the outputs to strings.
Use the `Subtokenizer` and vocab file as described in step 1 to decode the output integer arrays.
```
from official.transformer.utils.tokenizer import Subtokenizer
s = Subtokenizer(PATH_TO_VOCAB_FILE)
print(s.decode([18744, 145, 297, 1]))
```
The decoded outputs from above are:
* `[18744, 145, 297, 1] = "Hallo Welt<EOS>"`
* `[5450, 4642, 21, 11, 297, 3, 1] = "Abschied von der Welt.<EOS>"`
* `[25940, 22, 66, 103, 21713, 31, 102, 1] = "Möchten Sie einen Kuchen?<EOS>"`
## Implementation overview
A brief look at each component in the code:
### Model Definition
The [model](model) subdirectory contains the implementation of the Transformer model. The following files define the Transformer model and its layers:
* [transformer.py](model/transformer.py): Defines the transformer model and its encoder/decoder layer stacks.
* [embedding_layer.py](model/embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
* [attention_layer.py](model/attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
* [ffn_layer.py](model/ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
Other files:
* [beam_search.py](model/beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
* [model_params.py](model/model_params.py) contains the parameters used for the big and base models.
* [model_utils.py](model/model_utils.py) defines some helper functions used in the model (calculating padding, bias, etc.).
### Model Estimator
[transformer_main.py](model/transformer.py) creates an `Estimator` to train and evaluate the model.
Helper functions:
* [utils/dataset.py](utils/dataset.py): contains functions for creating a `dataset` that is passed to the `Estimator`.
* [utils/metrics.py](utils/metrics.py): defines metrics functions used by the `Estimator` to evaluate the
### Other scripts
Aside from the main file to train the Transformer model, we provide other scripts for using the model or downloading the data:
#### Data download and preprocessing
[data_download.py](data_download.py) downloads and extracts data, then uses `Subtokenizer` to tokenize strings into arrays of int IDs. The int arrays are converted to `tf.Examples` and saved in the `tf.RecordDataset` format.
The data is downloaded from the Workshop of Machine Transtion (WMT) [news translation task](http://www.statmt.org/wmt17/translation-task.html). The following datasets are used:
* Europarl v7
* Common Crawl corpus
* News Commentary v12
See the [download section](http://www.statmt.org/wmt17/translation-task.html#download) to explore the raw datasets. The parameters in this model are tuned to fit the English-German translation data, so the EN-DE texts are extracted from the downloaded compressed files.
The text is transformed into arrays of integer IDs using the `Subtokenizer` defined in [`utils/tokenizer.py`](util/tokenizer.py). During initialization of the `Subtokenizer`, the raw training data is used to generate a vocabulary list containing common subtokens.
The target vocabulary size of the WMT dataset is 32,768. The set of subtokens is found through binary search on the minimum number of times a subtoken appears in the data. The actual vocabulary size is 33,708, and is stored in a 324kB file.
#### Translation
Translation is defined in [translate.py](translate.py). First, `Subtokenizer` tokenizes the input. The vocabulary file is the same used to tokenize the training/eval files. Next, beam search is used to find the combination of tokens that maximizes the probability outputted by the model decoder. The tokens are then converted back to strings with `Subtokenizer`.
#### BLEU computation
[compute_bleu.py](compute_bleu.py): Implementation from [https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py).
### Test dataset
The [newstest2014 files](test_data) are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data). The raw text files are converted from the SGM format of the [WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets.
## Term definitions
**Steps / Epochs**:
* Step: unit for processing a single batch of data
* Epoch: a complete run through the dataset
Example: Consider a training a dataset with 100 examples that is divided into 20 batches with 5 examples per batch. A single training step trains the model on one batch. After 20 training steps, the model will have trained on every batch in the dataset, or one epoch.
**Subtoken**: Words are referred to as tokens, and parts of words are referred to as 'subtokens'. For example, the word 'inclined' may be split into `['incline', 'd_']`. The '\_' indicates the end of the token. The subtoken vocabulary list is guaranteed to contain the alphabet (including numbers and special characters), so all words can be tokenized.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.
Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import sys
import unicodedata
# pylint: disable=g-bad-import-order
import six
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import metrics
from official.utils.flags import core as flags_core
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
return "".join(six.unichr(x) for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/'
'blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a numer (e.g. a year) followed by a dot at the end of sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = tf.gfile.Open(ref_filename).read().strip().splitlines()
hyp_lines = tf.gfile.Open(hyp_filename).read().strip().splitlines()
if len(ref_lines) != len(hyp_lines):
raise ValueError("Reference and translation files have different number of "
"lines.")
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100
def main(unused_argv):
if FLAGS.bleu_variant in ("both", "uncased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
tf.logging.info("Case-insensitive results: %f" % score)
if FLAGS.bleu_variant in ("both", "cased"):
score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
tf.logging.info("Case-sensitive results: %f" % score)
def define_compute_bleu_flags():
"""Add flags for computing BLEU score."""
flags.DEFINE_string(
name="translation", default=None,
help=flags_core.help_wrap("File containing translated text."))
flags.mark_flag_as_required("translation")
flags.DEFINE_string(
name="reference", default=None,
help=flags_core.help_wrap("File containing reference translation."))
flags.mark_flag_as_required("reference")
flags.DEFINE_enum(
name="bleu_variant", short_name="bv", default="both",
enum_values=["both", "uncased", "cased"], case_sensitive=False,
help=flags_core.help_wrap(
"Specify one or more BLEU variants to calculate. Variants: \"cased\""
", \"uncased\", or \"both\"."))
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
define_compute_bleu_flags()
FLAGS = flags.FLAGS
absl_app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_blue.py."""
import tempfile
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer import compute_bleu
class ComputeBleuTest(tf.test.TestCase):
def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.gfile.Open(temp_file.name, 'w') as w:
w.write(text)
return temp_file.name
def test_bleu_same(self):
ref = self._create_temp_file("test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nmore tests!")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertEqual(100, cased_score)
def test_bleu_same_different_case(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)
def test_bleu_different(self):
ref = self._create_temp_file("Testing\nmore tests!")
hyp = self._create_temp_file("Dog\nCat")
uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
self.assertLess(uncased_score, 100)
self.assertLess(cased_score, 100)
def test_bleu_tokenize(self):
s = "Test0, 1 two, 3"
tokenized = compute_bleu.bleu_tokenize(s)
self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment