"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "61fb912ca46fe902180892316f6cc34adda07b67"
Commit 5a66a397 authored by Barret Zoph's avatar Barret Zoph
Browse files

Added autoaugment model.

parent 4e92bc57
...@@ -4,6 +4,7 @@ research/adversarial_text/* @rsepassi ...@@ -4,6 +4,7 @@ research/adversarial_text/* @rsepassi
research/adv_imagenet_models/* @AlexeyKurakin research/adv_imagenet_models/* @AlexeyKurakin
research/attention_ocr/* @alexgorban research/attention_ocr/* @alexgorban
research/audioset/* @plakal @dpwe research/audioset/* @plakal @dpwe
research/autoaugment/* @barretzoph
research/autoencoders/* @snurkabill research/autoencoders/* @snurkabill
research/cognitive_mapping_and_planning/* @s-gupta research/cognitive_mapping_and_planning/* @s-gupta
research/compression/* @nmjohn research/compression/* @nmjohn
......
<font size=4><b>Train Wide-ResNet, Shake-Shake and ShakeDrop models on CIFAR-10
and CIFAR-100 dataset with AutoAugment.</b></font>
The CIFAR-10/CIFAR-100 data can be downloaded from:
https://www.cs.toronto.edu/~kriz/cifar.html.
The code replicates the results from Tables 1 and 2 on CIFAR-10/100 with the
following models: Wide-ResNet-28-10, Shake-Shake (26 2x32d), Shake-Shake (26
2x96d) and PyramidNet+ShakeDrop.
<b>Related papers:</b>
AutoAugment: Learning Augmentation Policies from Data
https://arxiv.org/abs/1805.09501
Wide Residual Networks
https://arxiv.org/abs/1605.07146
Shake-Shake regularization
https://arxiv.org/abs/1705.07485
ShakeDrop regularization
https://arxiv.org/abs/1802.02375
<b>Settings:</b>
CIFAR-10 Model | Learning Rate | Weight Decay | Num. Epochs | Batch Size
---------------------- | ------------- | ------------ | ----------- | ----------
Wide-ResNet-28-10 | 0.1 | 5e-4 | 200 | 128
Shake-Shake (26 2x32d) | 0.01 | 1e-3 | 1800 | 128
Shake-Shake (26 2x96d) | 0.01 | 1e-3 | 1800 | 128
PyramidNet + ShakeDrop | 0.05 | 5e-5 | 1800 | 64
<b>Prerequisite:</b>
1. Install TensorFlow.
2. Download CIFAR-10/CIFAR-100 dataset.
```shell
curl -o cifar-10-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz
```
<b>How to run:</b>
```shell
# cd to the your workspace.
# Specify the directory where dataset is located using the data_path flag.
# Note: User can split samples from training set into the eval set by changing train_size and validation_size.
# For example, to train the Wide-ResNet-28-10 model on a GPU.
python train_cifar.py --model_name=wrn \
--checkpoint_dir=/tmp/training \
--data_path=/tmp/data \
--dataset='cifar10' \
--use_cpu=0
```
## Contact for Issues
* Barret Zoph, @barretzoph <barretzoph@google.com>
* Ekin Dogus Cubuk, <cubuk@google.com>
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transforms used in the Augmentation Policies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
# pylint:disable=g-multiple-import
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
# pylint:enable=g-multiple-import
IMAGE_SIZE = 32
# What is the dataset mean and std of the images on the training set
MEANS = [0.49139968, 0.48215841, 0.44653091]
STDS = [0.24703223, 0.24348513, 0.26158784]
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
def random_flip(x):
"""Flip the input x horizontally with 50% probability."""
if np.random.rand(1)[0] > 0.5:
return np.fliplr(x)
return x
def zero_pad_and_crop(img, amount=4):
"""Zero pad by `amount` zero pixels on each side then take a random crop.
Args:
img: numpy image that will be zero padded and cropped.
amount: amount of zeros to pad `img` with horizontally and verically.
Returns:
The cropped zero padded img. The returned numpy array will be of the same
shape as `img`.
"""
padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
img.shape[2]))
padded_img[amount:img.shape[0] + amount, amount:
img.shape[1] + amount, :] = img
top = np.random.randint(low=0, high=2 * amount)
left = np.random.randint(low=0, high=2 * amount)
new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
return new_img
def create_cutout_mask(img_height, img_width, num_channels, size):
"""Creates a zero mask used for cutout of shape `img_height` x `img_width`.
Args:
img_height: Height of image cutout mask will be applied to.
img_width: Width of image cutout mask will be applied to.
num_channels: Number of channels in the image.
size: Size of the zeros mask.
Returns:
A mask of shape `img_height` x `img_width` with all ones except for a
square of zeros of shape `size` x `size`. This mask is meant to be
elementwise multiplied with the original image. Additionally returns
the `upper_coord` and `lower_coord` which specify where the cutout mask
will be applied.
"""
assert img_height == img_width
# Sample center where cutout mask will be applied
height_loc = np.random.randint(low=0, high=img_height)
width_loc = np.random.randint(low=0, high=img_width)
# Determine upper right and lower left corners of patch
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
lower_coord = (min(img_height, height_loc + size // 2),
min(img_width, width_loc + size // 2))
mask_height = lower_coord[0] - upper_coord[0]
mask_width = lower_coord[1] - upper_coord[1]
assert mask_height > 0
assert mask_width > 0
mask = np.ones((img_height, img_width, num_channels))
zeros = np.zeros((mask_height, mask_width, num_channels))
mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
zeros)
return mask, upper_coord, lower_coord
def cutout_numpy(img, size=16):
"""Apply cutout with mask of shape `size` x `size` to `img`.
The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
This operation applies a `size`x`size` mask of zeros to a random location
within `img`.
Args:
img: Numpy image that cutout will be applied to.
size: Height/width of the cutout mask that will be
Returns:
A numpy tensor that is the result of applying the cutout mask to `img`.
"""
img_height, img_width, num_channels = (img.shape[0], img.shape[1],
img.shape[2])
assert len(img.shape) == 3
mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
return img * mask
def float_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
A float that results from scaling `maxval` according to `level`.
"""
return float(level) * maxval / PARAMETER_MAX
def int_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
An int that results from scaling `maxval` according to `level`.
"""
return int(level * maxval / PARAMETER_MAX)
def pil_wrap(img):
"""Convert the `img` numpy tensor to a PIL Image."""
return Image.fromarray(
np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')
def pil_unwrap(pil_img):
"""Converts the PIL img to a numpy array."""
pic_array = (np.array(pil_img.getdata()).reshape((32, 32, 4)) / 255.0)
i1, i2 = np.where(pic_array[:, :, 3] == 0)
pic_array = (pic_array[:, :, :3] - MEANS) / STDS
pic_array[i1, i2] = [0, 0, 0]
return pic_array
def apply_policy(policy, img):
"""Apply the `policy` to the numpy `img`.
Args:
policy: A list of tuples with the form (name, probability, level) where
`name` is the name of the augmentation operation to apply, `probability`
is the probability of applying the operation and `level` is what strength
the operation to apply.
img: Numpy image that will have `policy` applied to it.
Returns:
The result of applying `policy` to `img`.
"""
pil_img = pil_wrap(img)
for xform in policy:
assert len(xform) == 3
name, probability, level = xform
xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(probability, level)
pil_img = xform_fn(pil_img)
return pil_unwrap(pil_img)
class TransformFunction(object):
"""Wraps the Transform function for pretty printing options."""
def __init__(self, func, name):
self.f = func
self.name = name
def __repr__(self):
return '<' + self.name + '>'
def __call__(self, pil_img):
return self.f(pil_img)
class TransformT(object):
"""Each instance of this class represents a specific transform."""
def __init__(self, name, xform_fn):
self.name = name
self.xform = xform_fn
def pil_transformer(self, probability, level):
def return_function(im):
if random.random() < probability:
im = self.xform(im, level)
return im
name = self.name + '({:.1f},{})'.format(probability, level)
return TransformFunction(return_function, name)
def do_transform(self, image, level):
f = self.pil_transformer(PARAMETER_MAX, level)
return pil_unwrap(f(pil_wrap(image)))
################## Transform Functions ##################
identity = TransformT('identity', lambda pil_img, level: pil_img)
flip_lr = TransformT(
'FlipLR',
lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
flip_ud = TransformT(
'FlipUD',
lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
# pylint:disable=g-long-lambda
auto_contrast = TransformT(
'AutoContrast',
lambda pil_img, level: ImageOps.autocontrast(
pil_img.convert('RGB')).convert('RGBA'))
equalize = TransformT(
'Equalize',
lambda pil_img, level: ImageOps.equalize(
pil_img.convert('RGB')).convert('RGBA'))
invert = TransformT(
'Invert',
lambda pil_img, level: ImageOps.invert(
pil_img.convert('RGB')).convert('RGBA'))
# pylint:enable=g-long-lambda
blur = TransformT(
'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR))
smooth = TransformT(
'Smooth',
lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH))
def _rotate_impl(pil_img, level):
"""Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
degrees = int_parameter(level, 30)
if random.random() > 0.5:
degrees = -degrees
return pil_img.rotate(degrees)
rotate = TransformT('Rotate', _rotate_impl)
def _posterize_impl(pil_img, level):
"""Applies PIL Posterize to `pil_img`."""
level = int_parameter(level, 4)
return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
posterize = TransformT('Posterize', _posterize_impl)
def _shear_x_impl(pil_img, level):
"""Applies PIL ShearX to `pil_img`.
The ShearX operation shears the image along the horizontal axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0))
shear_x = TransformT('ShearX', _shear_x_impl)
def _shear_y_impl(pil_img, level):
"""Applies PIL ShearY to `pil_img`.
The ShearY operation shears the image along the vertical axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0))
shear_y = TransformT('ShearY', _shear_y_impl)
def _translate_x_impl(pil_img, level):
"""Applies PIL TranslateX to `pil_img`.
Translate the image in the horizontal direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateX applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0))
translate_x = TransformT('TranslateX', _translate_x_impl)
def _translate_y_impl(pil_img, level):
"""Applies PIL TranslateY to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateY applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level))
translate_y = TransformT('TranslateY', _translate_y_impl)
def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
"""Applies a crop to `pil_img` with the size depending on the `level`."""
cropped = pil_img.crop((level, level, IMAGE_SIZE - level, IMAGE_SIZE - level))
resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)
return resized
crop_bilinear = TransformT('CropBilinear', _crop_impl)
def _solarize_impl(pil_img, level):
"""Applies PIL Solarize to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had Solarize applied to it.
"""
level = int_parameter(level, 256)
return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
solarize = TransformT('Solarize', _solarize_impl)
def _cutout_pil_impl(pil_img, level):
"""Apply cutout to pil_img at the specified level."""
size = int_parameter(level, 20)
if size <= 0:
return pil_img
img_height, img_width, num_channels = (32, 32, 3)
_, upper_coord, lower_coord = (
create_cutout_mask(img_height, img_width, num_channels, size))
pixels = pil_img.load() # create the pixel map
for i in range(upper_coord[0], lower_coord[0]): # for every col:
for j in range(upper_coord[1], lower_coord[1]): # For every row
pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
return pil_img
cutout = TransformT('Cutout', _cutout_pil_impl)
def _enhancer_impl(enhancer):
"""Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
def impl(pil_img, level):
v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
return enhancer(pil_img).enhance(v)
return impl
color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(
ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
ALL_TRANSFORMS = [
flip_lr,
flip_ud,
auto_contrast,
equalize,
invert,
rotate,
posterize,
crop_bilinear,
solarize,
color,
contrast,
brightness,
sharpness,
shear_x,
shear_y,
translate_x,
translate_y,
cutout,
blur,
smooth
]
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains convenience wrappers for typical Neural Network TensorFlow layers.
Ops that have different behavior during training or eval have an is_training
parameter.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
arg_scope = tf.contrib.framework.arg_scope
def variable(name, shape, dtype, initializer, trainable):
"""Returns a TF variable with the passed in specifications."""
var = tf.get_variable(
name,
shape=shape,
dtype=dtype,
initializer=initializer,
trainable=trainable)
return var
def global_avg_pool(x, scope=None):
"""Average pools away spatial height and width dimension of 4D tensor."""
assert x.get_shape().ndims == 4
with tf.name_scope(scope, 'global_avg_pool', [x]):
kernel_size = (1, int(x.shape[1]), int(x.shape[2]), 1)
squeeze_dims = (1, 2)
result = tf.nn.avg_pool(
x,
ksize=kernel_size,
strides=(1, 1, 1, 1),
padding='VALID',
data_format='NHWC')
return tf.squeeze(result, squeeze_dims)
def zero_pad(inputs, in_filter, out_filter):
"""Zero pads `input` tensor to have `out_filter` number of filters."""
outputs = tf.pad(inputs, [[0, 0], [0, 0], [0, 0],
[(out_filter - in_filter) // 2,
(out_filter - in_filter) // 2]])
return outputs
@tf.contrib.framework.add_arg_scope
def batch_norm(inputs,
decay=0.999,
center=True,
scale=False,
epsilon=0.001,
is_training=True,
reuse=None,
scope=None):
"""Small wrapper around tf.contrib.layers.batch_norm."""
return tf.contrib.layers.batch_norm(
inputs,
decay=decay,
center=center,
scale=scale,
epsilon=epsilon,
activation_fn=None,
param_initializers=None,
updates_collections=tf.GraphKeys.UPDATE_OPS,
is_training=is_training,
reuse=reuse,
trainable=True,
fused=True,
data_format='NHWC',
zero_debias_moving_mean=False,
scope=scope)
def stride_arr(stride_h, stride_w):
return [1, stride_h, stride_w, 1]
@tf.contrib.framework.add_arg_scope
def conv2d(inputs,
num_filters_out,
kernel_size,
stride=1,
scope=None,
reuse=None):
"""Adds a 2D convolution.
conv2d creates a variable called 'weights', representing the convolutional
kernel, that is convolved with the input.
Args:
inputs: a 4D tensor in NHWC format.
num_filters_out: the number of output filters.
kernel_size: an int specifying the kernel height and width size.
stride: an int specifying the height and width stride.
scope: Optional scope for variable_scope.
reuse: whether or not the layer and its variables should be reused.
Returns:
a tensor that is the result of a convolution being applied to `inputs`.
"""
with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse):
num_filters_in = int(inputs.shape[3])
weights_shape = [kernel_size, kernel_size, num_filters_in, num_filters_out]
# Initialization
n = int(weights_shape[0] * weights_shape[1] * weights_shape[3])
weights_initializer = tf.random_normal_initializer(
stddev=np.sqrt(2.0 / n))
weights = variable(
name='weights',
shape=weights_shape,
dtype=tf.float32,
initializer=weights_initializer,
trainable=True)
strides = stride_arr(stride, stride)
outputs = tf.nn.conv2d(
inputs, weights, strides, padding='SAME', data_format='NHWC')
return outputs
@tf.contrib.framework.add_arg_scope
def fc(inputs,
num_units_out,
scope=None,
reuse=None):
"""Creates a fully connected layer applied to `inputs`.
Args:
inputs: a tensor that the fully connected layer will be applied to. It
will be reshaped if it is not 2D.
num_units_out: the number of output units in the layer.
scope: Optional scope for variable_scope.
reuse: whether or not the layer and its variables should be reused.
Returns:
a tensor that is the result of applying a linear matrix to `inputs`.
"""
if len(inputs.shape) > 2:
inputs = tf.reshape(inputs, [int(inputs.shape[0]), -1])
with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse):
num_units_in = inputs.shape[1]
weights_shape = [num_units_in, num_units_out]
unif_init_range = 1.0 / (num_units_out)**(0.5)
weights_initializer = tf.random_uniform_initializer(
-unif_init_range, unif_init_range)
weights = variable(
name='weights',
shape=weights_shape,
dtype=tf.float32,
initializer=weights_initializer,
trainable=True)
bias_initializer = tf.constant_initializer(0.0)
biases = variable(
name='biases',
shape=[num_units_out,],
dtype=tf.float32,
initializer=bias_initializer,
trainable=True)
outputs = tf.nn.xw_plus_b(inputs, weights, biases)
return outputs
@tf.contrib.framework.add_arg_scope
def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
"""Wrapper around tf.nn.avg_pool."""
with tf.name_scope(scope, 'AvgPool', [inputs]):
kernel = stride_arr(kernel_size, kernel_size)
strides = stride_arr(stride, stride)
return tf.nn.avg_pool(
inputs,
ksize=kernel,
strides=strides,
padding=padding,
data_format='NHWC')
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utils for CIFAR-10 and CIFAR-100."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import cPickle
import os
import augmentation_transforms
import numpy as np
import policies as found_policies
import tensorflow as tf
# pylint:disable=logging-format-interpolation
class DataSet(object):
"""Dataset object that produces augmented training and eval data."""
def __init__(self, hparams):
self.hparams = hparams
self.epochs = 0
self.curr_train_index = 0
all_labels = []
self.good_policies = found_policies.good_policies()
# Determine how many databatched to load
num_data_batches_to_load = 5
total_batches_to_load = num_data_batches_to_load
train_batches_to_load = total_batches_to_load
assert hparams.train_size + hparams.validation_size <= 50000
if hparams.eval_test:
total_batches_to_load += 1
# Determine how many images we have loaded
total_dataset_size = 10000 * num_data_batches_to_load
train_dataset_size = total_dataset_size
if hparams.eval_test:
total_dataset_size += 10000
if hparams.dataset == 'cifar10':
all_data = np.empty((total_batches_to_load, 10000, 3072), dtype=np.uint8)
elif hparams.dataset == 'cifar100':
assert num_data_batches_to_load == 5
all_data = np.empty((1, 50000, 3072), dtype=np.uint8)
if hparams.eval_test:
test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
if hparams.dataset == 'cifar10':
tf.logging.info('Cifar10')
datafiles = [
'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
'data_batch_5']
datafiles = datafiles[:train_batches_to_load]
if hparams.eval_test:
datafiles.append('test_batch')
num_classes = 10
elif hparams.dataset == 'cifar100':
datafiles = ['train']
if hparams.eval_test:
datafiles.append('test')
num_classes = 100
else:
raise NotImplementedError('Unimplemented dataset: ', hparams.dataset)
if hparams.dataset != 'test':
for file_num, f in enumerate(datafiles):
d = unpickle(os.path.join(hparams.data_path, f))
if f == 'test':
test_data[0] = copy.deepcopy(d['data'])
all_data = np.concatenate([all_data, test_data], axis=1)
else:
all_data[file_num] = copy.deepcopy(d['data'])
if hparams.dataset == 'cifar10':
labels = np.array(d['labels'])
else:
labels = np.array(d['fine_labels'])
nsamples = len(labels)
for idx in range(nsamples):
all_labels.append(labels[idx])
all_data = all_data.reshape(total_dataset_size, 3072)
all_data = all_data.reshape(-1, 3, 32, 32)
all_data = all_data.transpose(0, 2, 3, 1).copy()
all_data = all_data / 255.0
mean = augmentation_transforms.MEANS
std = augmentation_transforms.STDS
tf.logging.info('mean:{} std: {}'.format(mean, std))
all_data = (all_data - mean) / std
all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
assert len(all_data) == len(all_labels)
tf.logging.info(
'In CIFAR10 loader, number of images: {}'.format(len(all_data)))
# Break off test data
if hparams.eval_test:
self.test_images = all_data[train_dataset_size:]
self.test_labels = all_labels[train_dataset_size:]
# Shuffle the rest of the data
all_data = all_data[:train_dataset_size]
all_labels = all_labels[:train_dataset_size]
np.random.seed(0)
perm = np.arange(len(all_data))
np.random.shuffle(perm)
all_data = all_data[perm]
all_labels = all_labels[perm]
# Break into train and val
train_size, val_size = hparams.train_size, hparams.validation_size
assert 50000 >= train_size + val_size
self.train_images = all_data[:train_size]
self.train_labels = all_labels[:train_size]
self.val_images = all_data[train_size:train_size + val_size]
self.val_labels = all_labels[train_size:train_size + val_size]
self.num_train = self.train_images.shape[0]
def next_batch(self):
"""Return the next minibatch of augmented data."""
next_train_index = self.curr_train_index + self.hparams.batch_size
if next_train_index > self.num_train:
# Increase epoch number
epoch = self.epochs + 1
self.reset()
self.epochs = epoch
batched_data = (
self.train_images[self.curr_train_index:
self.curr_train_index + self.hparams.batch_size],
self.train_labels[self.curr_train_index:
self.curr_train_index + self.hparams.batch_size])
final_imgs = []
images, labels = batched_data
for data in images:
epoch_policy = self.good_policies[np.random.choice(
len(self.good_policies))]
final_img = augmentation_transforms.apply_policy(
epoch_policy, data)
final_img = augmentation_transforms.random_flip(
augmentation_transforms.zero_pad_and_crop(final_img, 4))
# Apply cutout
final_img = augmentation_transforms.cutout_numpy(final_img)
final_imgs.append(final_img)
batched_data = (np.array(final_imgs, np.float32), labels)
self.curr_train_index += self.hparams.batch_size
return batched_data
def reset(self):
"""Reset training data and index into the training data."""
self.epochs = 0
# Shuffle the training data
perm = np.arange(self.num_train)
np.random.shuffle(perm)
assert self.num_train == self.train_images.shape[
0], 'Error incorrect shuffling mask'
self.train_images = self.train_images[perm]
self.train_labels = self.train_labels[perm]
self.curr_train_index = 0
def unpickle(f):
tf.logging.info('loading file: {}'.format(f))
fo = tf.gfile.Open(f, 'r')
d = cPickle.load(fo)
fo.close()
return d
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions used for training AutoAugment models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def setup_loss(logits, labels):
"""Returns the cross entropy for the given `logits` and `labels`."""
predictions = tf.nn.softmax(logits)
cost = tf.losses.softmax_cross_entropy(onehot_labels=labels,
logits=logits)
return predictions, cost
def decay_weights(cost, weight_decay_rate):
"""Calculates the loss for l2 weight decay and adds it to `cost`."""
costs = []
for var in tf.trainable_variables():
costs.append(tf.nn.l2_loss(var))
cost += tf.multiply(weight_decay_rate, tf.add_n(costs))
return cost
def eval_child_model(session, model, data_loader, mode):
"""Evaluates `model` on held out data depending on `mode`.
Args:
session: TensorFlow session the model will be run with.
model: TensorFlow model that will be evaluated.
data_loader: DataSet object that contains data that `model` will
evaluate.
mode: Will `model` either evaluate validation or test data.
Returns:
Accuracy of `model` when evaluated on the specified dataset.
Raises:
ValueError: if invalid dataset `mode` is specified.
"""
if mode == 'val':
images = data_loader.val_images
labels = data_loader.val_labels
elif mode == 'test':
images = data_loader.test_images
labels = data_loader.test_labels
else:
raise ValueError('Not valid eval mode')
assert len(images) == len(labels)
tf.logging.info('model.batch_size is {}'.format(model.batch_size))
assert len(images) % model.batch_size == 0
eval_batches = int(len(images) / model.batch_size)
for i in range(eval_batches):
eval_images = images[i * model.batch_size:(i + 1) * model.batch_size]
eval_labels = labels[i * model.batch_size:(i + 1) * model.batch_size]
_ = session.run(
model.eval_op,
feed_dict={
model.images: eval_images,
model.labels: eval_labels,
})
return session.run(model.accuracy)
def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
"""Cosine Learning rate.
Args:
learning_rate: Initial learning rate.
epoch: Current epoch we are one. This is one based.
iteration: Current batch in this epoch.
batches_per_epoch: Batches per epoch.
total_epochs: Total epochs you are training for.
Returns:
The learning rate to be used for this current batch.
"""
t_total = total_epochs * batches_per_epoch
t_cur = float(epoch * batches_per_epoch + iteration)
return 0.5 * learning_rate * (1 + np.cos(np.pi * t_cur / t_total))
def get_lr(curr_epoch, hparams, iteration=None):
"""Returns the learning rate during training based on the current epoch."""
assert iteration is not None
batches_per_epoch = int(hparams.train_size / hparams.batch_size)
lr = cosine_lr(hparams.lr, curr_epoch, iteration, batches_per_epoch,
hparams.num_epochs)
return lr
def run_epoch_training(session, model, data_loader, curr_epoch):
"""Runs one epoch of training for the model passed in.
Args:
session: TensorFlow session the model will be run with.
model: TensorFlow model that will be evaluated.
data_loader: DataSet object that contains data that `model` will
evaluate.
curr_epoch: How many of epochs of training have been done so far.
Returns:
The accuracy of 'model' on the training set
"""
steps_per_epoch = int(model.hparams.train_size / model.hparams.batch_size)
tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))
curr_step = session.run(model.global_step)
assert curr_step % steps_per_epoch == 0
# Get the current learning rate for the model based on the current epoch
curr_lr = get_lr(curr_epoch, model.hparams, iteration=0)
tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))
for step in xrange(steps_per_epoch):
curr_lr = get_lr(curr_epoch, model.hparams, iteration=(step + 1))
# Update the lr rate variable to the current LR.
model.lr_rate_ph.load(curr_lr, session=session)
if step % 20 == 0:
tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))
train_images, train_labels = data_loader.next_batch()
_, step, _ = session.run(
[model.train_op, model.global_step, model.eval_op],
feed_dict={
model.images: train_images,
model.labels: train_labels,
})
train_accuracy = session.run(model.accuracy)
tf.logging.info('Train accuracy: {}'.format(train_accuracy))
return train_accuracy
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def good_policies():
"""AutoAugment policies found on Cifar."""
exp0_0 = [
[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
[('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
[('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
[('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
exp0_1 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
[('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
[('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
exp0_2 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
[('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
[('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
exp0_3 = [
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
[('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
[('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
[('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
[('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
exp1_0 = [
[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
[('Color', 0.4, 3), ('Brightness', 0.6, 7)],
[('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
[('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
[('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
exp1_1 = [
[('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
[('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
[('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
[('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
[('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
exp1_2 = [
[('Solarize', 0.2, 6), ('Color', 0.8, 6)],
[('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
[('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
[('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
[('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
exp1_3 = [
[('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
[('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
[('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
[('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
[('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
exp1_4 = [
[('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
[('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
[('Equalize', 0.6, 8), ('Color', 0.6, 2)],
[('Color', 0.3, 7), ('Color', 0.2, 4)],
[('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
exp1_5 = [
[('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
[('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
[('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
[('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
[('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
exp1_6 = [
[('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
[('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
[('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
[('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
[('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
exp2_0 = [
[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
[('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
[('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
[('Brightness', 0.9, 6), ('Color', 0.2, 8)],
[('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
exp2_1 = [
[('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
[('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
[('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
[('Color', 0.1, 8), ('ShearY', 0.2, 3)],
[('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
exp2_2 = [
[('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
[('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
[('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
[('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
[('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
exp2_3 = [
[('Equalize', 0.9, 5), ('Color', 0.7, 0)],
[('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
[('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
[('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
[('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
exp2_4 = [
[('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
[('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
[('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
[('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
[('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
exp2_5 = [
[('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
[('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
[('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
[('Solarize', 0.4, 3), ('Color', 0.2, 4)],
[('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
exp2_6 = [
[('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
[('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
[('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
[('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
exp2_7 = [
[('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
[('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
[('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
[('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
[('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
return exp0s + exp1s + exp2s
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Shake-Shake Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import custom_ops as ops
import tensorflow as tf
def round_int(x):
"""Rounds `x` and then converts to an int."""
return int(math.floor(x + 0.5))
def shortcut(x, output_filters, stride):
"""Applies strided avg pool or zero padding to make output_filters match x."""
num_filters = int(x.shape[3])
if stride == 2:
x = ops.avg_pool(x, 2, stride=stride, padding='SAME')
if num_filters != output_filters:
diff = output_filters - num_filters
assert diff > 0
# Zero padd diff zeros
padding = [[0, 0], [0, 0], [0, 0], [0, diff]]
x = tf.pad(x, padding)
return x
def calc_prob(curr_layer, total_layers, p_l):
"""Calculates drop prob depending on the current layer."""
return 1 - (float(curr_layer) / total_layers) * p_l
def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta):
"""Bottleneck layer for shake drop model."""
assert alpha[1] > alpha[0]
assert beta[1] > beta[0]
with tf.variable_scope('bottleneck_{}'.format(prob)):
input_layer = x
x = ops.batch_norm(x, scope='bn_1_pre')
x = ops.conv2d(x, n, 1, scope='1x1_conv_contract')
x = ops.batch_norm(x, scope='bn_1_post')
x = tf.nn.relu(x)
x = ops.conv2d(x, n, 3, stride=stride, scope='3x3')
x = ops.batch_norm(x, scope='bn_2')
x = tf.nn.relu(x)
x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand')
x = ops.batch_norm(x, scope='bn_3')
# Apply regularization here
# Sample bernoulli with prob
if is_training:
batch_size = tf.shape(x)[0]
bern_shape = [batch_size, 1, 1, 1]
random_tensor = prob
random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
binary_tensor = tf.floor(random_tensor)
alpha_values = tf.random_uniform(
[batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1],
dtype=tf.float32)
beta_values = tf.random_uniform(
[batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1],
dtype=tf.float32)
rand_forward = (
binary_tensor + alpha_values - binary_tensor * alpha_values)
rand_backward = (
binary_tensor + beta_values - binary_tensor * beta_values)
x = x * rand_backward + tf.stop_gradient(x * rand_forward -
x * rand_backward)
else:
expected_alpha = (alpha[1] + alpha[0])/2
# prob is the expectation of the bernoulli variable
x = (prob + expected_alpha - prob * expected_alpha) * x
res = shortcut(input_layer, n * 4, stride)
return x + res
def build_shake_drop_model(images, num_classes, is_training):
"""Builds the PyramidNet Shake-Drop model.
Build the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375.
Args:
images: Tensor of images that will be fed into the Wide ResNet Model.
num_classes: Number of classed that the model needs to predict.
is_training: Is the model training or not.
Returns:
The logits of the PyramidNet Shake-Drop model.
"""
# ShakeDrop Hparams
p_l = 0.5
alpha_shake = [-1, 1]
beta_shake = [0, 1]
# PyramidNet Hparams
alpha = 200
depth = 272
# This is for the bottleneck architecture specifically
n = int((depth - 2) / 9)
start_channel = 16
add_channel = alpha / (3 * n)
# Building the models
x = images
x = ops.conv2d(x, 16, 3, scope='init_conv')
x = ops.batch_norm(x, scope='init_bn')
layer_num = 1
total_layers = n * 3
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 1, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
for _ in range(1, n):
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 1, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 2, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
for _ in range(1, n):
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 1, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 2, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
for _ in range(1, n):
start_channel += add_channel
prob = calc_prob(layer_num, total_layers, p_l)
x = bottleneck_layer(
x, round_int(start_channel), 1, prob, is_training, alpha_shake,
beta_shake)
layer_num += 1
assert layer_num - 1 == total_layers
x = ops.batch_norm(x, scope='final_bn')
x = tf.nn.relu(x)
x = ops.global_avg_pool(x)
# Fully connected
logits = ops.fc(x, num_classes)
return logits
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Shake-Shake Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import custom_ops as ops
import tensorflow as tf
def _shake_shake_skip_connection(x, output_filters, stride):
"""Adds a residual connection to the filter x for the shake-shake model."""
curr_filters = int(x.shape[3])
if curr_filters == output_filters:
return x
stride_spec = ops.stride_arr(stride, stride)
# Skip path 1
path1 = tf.nn.avg_pool(
x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
path1 = ops.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')
# Skip path 2
# First pad with 0's then crop
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
concat_axis = 3
path2 = tf.nn.avg_pool(
path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
path2 = ops.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')
# Concat and apply BN
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
final_path = ops.batch_norm(final_path, scope='final_path_bn')
return final_path
def _shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,
is_training):
"""Building a 2 branching convnet."""
x = tf.nn.relu(x)
x = ops.conv2d(x, output_filters, 3, stride=stride, scope='conv1')
x = ops.batch_norm(x, scope='bn1')
x = tf.nn.relu(x)
x = ops.conv2d(x, output_filters, 3, scope='conv2')
x = ops.batch_norm(x, scope='bn2')
if is_training:
x = x * rand_backward + tf.stop_gradient(x * rand_forward -
x * rand_backward)
else:
x *= 1.0 / 2
return x
def _shake_shake_block(x, output_filters, stride, is_training):
"""Builds a full shake-shake sub layer."""
batch_size = tf.shape(x)[0]
# Generate random numbers for scaling the branches
rand_forward = [
tf.random_uniform(
[batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
for _ in range(2)
]
rand_backward = [
tf.random_uniform(
[batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
for _ in range(2)
]
# Normalize so that all sum to 1
total_forward = tf.add_n(rand_forward)
total_backward = tf.add_n(rand_backward)
rand_forward = [samp / total_forward for samp in rand_forward]
rand_backward = [samp / total_backward for samp in rand_backward]
zipped_rand = zip(rand_forward, rand_backward)
branches = []
for branch, (r_forward, r_backward) in enumerate(zipped_rand):
with tf.variable_scope('branch_{}'.format(branch)):
b = _shake_shake_branch(x, output_filters, stride, r_forward, r_backward,
is_training)
branches.append(b)
res = _shake_shake_skip_connection(x, output_filters, stride)
return res + tf.add_n(branches)
def _shake_shake_layer(x, output_filters, num_blocks, stride,
is_training):
"""Builds many sub layers into one full layer."""
for block_num in range(num_blocks):
curr_stride = stride if (block_num == 0) else 1
with tf.variable_scope('layer_{}'.format(block_num)):
x = _shake_shake_block(x, output_filters, curr_stride,
is_training)
return x
def build_shake_shake_model(images, num_classes, hparams, is_training):
"""Builds the Shake-Shake model.
Build the Shake-Shake model from https://arxiv.org/abs/1705.07485.
Args:
images: Tensor of images that will be fed into the Wide ResNet Model.
num_classes: Number of classed that the model needs to predict.
hparams: tf.HParams object that contains additional hparams needed to
construct the model. In this case it is the `shake_shake_widen_factor`
that is used to determine how many filters the model has.
is_training: Is the model training or not.
Returns:
The logits of the Shake-Shake model.
"""
depth = 26
k = hparams.shake_shake_widen_factor # The widen factor
n = int((depth - 2) / 6)
x = images
x = ops.conv2d(x, 16, 3, scope='init_conv')
x = ops.batch_norm(x, scope='init_bn')
with tf.variable_scope('L1'):
x = _shake_shake_layer(x, 16 * k, n, 1, is_training)
with tf.variable_scope('L2'):
x = _shake_shake_layer(x, 32 * k, n, 2, is_training)
with tf.variable_scope('L3'):
x = _shake_shake_layer(x, 64 * k, n, 2, is_training)
x = tf.nn.relu(x)
x = ops.global_avg_pool(x)
# Fully connected
logits = ops.fc(x, num_classes)
return logits
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutoAugment Train/Eval module.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import time
import custom_ops as ops
import data_utils
import helper_utils
import numpy as np
from shake_drop import build_shake_drop_model
from shake_shake import build_shake_shake_model
import tensorflow as tf
from wrn import build_wrn_model
tf.flags.DEFINE_string('model_name', 'wrn',
'wrn, shake_shake_32, shake_shake_96, shake_shake_112, '
'pyramid_net')
tf.flags.DEFINE_string('checkpoint_dir', '/tmp/training', 'Training Directory.')
tf.flags.DEFINE_string('data_path', '/tmp/data',
'Directory where dataset is located.')
tf.flags.DEFINE_string('dataset', 'cifar10',
'Dataset to train with. Either cifar10 or cifar100')
tf.flags.DEFINE_integer('use_cpu', 1, '1 if use CPU, else GPU.')
FLAGS = tf.flags.FLAGS
arg_scope = tf.contrib.framework.arg_scope
def setup_arg_scopes(is_training):
"""Sets up the argscopes that will be used when building an image model.
Args:
is_training: Is the model training or not.
Returns:
Arg scopes to be put around the model being constructed.
"""
batch_norm_decay = 0.9
batch_norm_epsilon = 1e-5
batch_norm_params = {
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
'scale': True,
# collection containing the moving mean and moving variance.
'is_training': is_training,
}
scopes = []
scopes.append(arg_scope([ops.batch_norm], **batch_norm_params))
return scopes
def build_model(inputs, num_classes, is_training, hparams):
"""Constructs the vision model being trained/evaled.
Args:
inputs: input features/images being fed to the image model build built.
num_classes: number of output classes being predicted.
is_training: is the model training or not.
hparams: additional hyperparameters associated with the image model.
Returns:
The logits of the image model.
"""
scopes = setup_arg_scopes(is_training)
with contextlib.nested(*scopes):
if hparams.model_name == 'pyramid_net':
logits = build_shake_drop_model(
inputs, num_classes, is_training)
elif hparams.model_name == 'wrn':
logits = build_wrn_model(
inputs, num_classes, hparams.wrn_size)
elif hparams.model_name == 'shake_shake':
logits = build_shake_shake_model(
inputs, num_classes, hparams, is_training)
return logits
class CifarModel(object):
"""Builds an image model for Cifar10/Cifar100."""
def __init__(self, hparams):
self.hparams = hparams
def build(self, mode):
"""Construct the cifar model."""
assert mode in ['train', 'eval']
self.mode = mode
self._setup_misc(mode)
self._setup_images_and_labels()
self._build_graph(self.images, self.labels, mode)
self.init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
def _setup_misc(self, mode):
"""Sets up miscellaneous in the cifar model constructor."""
self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False)
self.reuse = None if (mode == 'train') else True
self.batch_size = self.hparams.batch_size
if mode == 'eval':
self.batch_size = 25
def _setup_images_and_labels(self):
"""Sets up image and label placeholders for the cifar model."""
if FLAGS.dataset == 'cifar10':
self.num_classes = 10
else:
self.num_classes = 100
self.images = tf.placeholder(tf.float32, [self.batch_size, 32, 32, 3])
self.labels = tf.placeholder(tf.float32,
[self.batch_size, self.num_classes])
def assign_epoch(self, session, epoch_value):
session.run(self._epoch_update, feed_dict={self._new_epoch: epoch_value})
def _build_graph(self, images, labels, mode):
"""Constructs the TF graph for the cifar model.
Args:
images: A 4-D image Tensor
labels: A 2-D labels Tensor.
mode: string indicating training mode ( e.g., 'train', 'valid', 'test').
"""
is_training = 'train' in mode
if is_training:
self.global_step = tf.train.get_or_create_global_step()
logits = build_model(
images,
self.num_classes,
is_training,
self.hparams)
self.predictions, self.cost = helper_utils.setup_loss(
logits, labels)
self.accuracy, self.eval_op = tf.metrics.accuracy(
tf.argmax(labels, 1), tf.argmax(self.predictions, 1))
self._calc_num_trainable_params()
# Adds L2 weight decay to the cost
self.cost = helper_utils.decay_weights(self.cost,
self.hparams.weight_decay_rate)
if is_training:
self._build_train_op()
# Setup checkpointing for this child model
# Keep 2 or more checkpoints around during training.
with tf.device('/cpu:0'):
self.saver = tf.train.Saver(max_to_keep=2)
self.init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
def _calc_num_trainable_params(self):
self.num_trainable_params = np.sum([
np.prod(var.get_shape().as_list()) for var in tf.trainable_variables()
])
tf.logging.info('number of trainable params: {}'.format(
self.num_trainable_params))
def _build_train_op(self):
"""Builds the train op for the cifar model."""
hparams = self.hparams
tvars = tf.trainable_variables()
grads = tf.gradients(self.cost, tvars)
if hparams.gradient_clipping_by_global_norm > 0.0:
grads, norm = tf.clip_by_global_norm(
grads, hparams.gradient_clipping_by_global_norm)
tf.summary.scalar('grad_norm', norm)
# Setup the initial learning rate
initial_lr = self.lr_rate_ph
optimizer = tf.train.MomentumOptimizer(
initial_lr,
0.9,
use_nesterov=True)
self.optimizer = optimizer
apply_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=self.global_step, name='train_step')
train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([apply_op]):
self.train_op = tf.group(*train_ops)
class CifarModelTrainer(object):
"""Trains an instance of the CifarModel class."""
def __init__(self, hparams):
self._session = None
self.hparams = hparams
self.model_dir = os.path.join(FLAGS.checkpoint_dir, 'model')
self.log_dir = os.path.join(FLAGS.checkpoint_dir, 'log')
# Set the random seed to be sure the same validation set
# is used for each model
np.random.seed(0)
self.data_loader = data_utils.DataSet(hparams)
np.random.seed() # Put the random seed back to random
self.data_loader.reset()
def save_model(self, step=None):
"""Dumps model into the backup_dir.
Args:
step: If provided, creates a checkpoint with the given step
number, instead of overwriting the existing checkpoints.
"""
model_save_name = os.path.join(self.model_dir, 'model.ckpt')
if not tf.gfile.IsDirectory(self.model_dir):
tf.gfile.MakeDirs(self.model_dir)
self.saver.save(self.session, model_save_name, global_step=step)
tf.logging.info('Saved child model')
def extract_model_spec(self):
"""Loads a checkpoint with the architecture structure stored in the name."""
checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
if checkpoint_path is not None:
self.saver.restore(self.session, checkpoint_path)
tf.logging.info('Loaded child model checkpoint from %s',
checkpoint_path)
else:
self.save_model(step=0)
def eval_child_model(self, model, data_loader, mode):
"""Evaluate the child model.
Args:
model: image model that will be evaluated.
data_loader: dataset object to extract eval data from.
mode: will the model be evalled on train, val or test.
Returns:
Accuracy of the model on the specified dataset.
"""
tf.logging.info('Evaluating child model in mode %s', mode)
while True:
try:
with self._new_session(model):
accuracy = helper_utils.eval_child_model(
self.session,
model,
data_loader,
mode)
tf.logging.info('Eval child model accuracy: {}'.format(accuracy))
# If epoch trained without raising the below errors, break
# from loop.
break
except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
tf.logging.info('Retryable error caught: %s. Retrying.', e)
return accuracy
@contextlib.contextmanager
def _new_session(self, m):
"""Creates a new session for model m."""
# Create a new session for this model, initialize
# variables, and save / restore from
# checkpoint.
self._session = tf.Session(
'',
config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False))
self.session.run(m.init)
# Load in a previous checkpoint, or save this one
self.extract_model_spec()
try:
yield
finally:
tf.Session.reset('')
self._session = None
def _build_models(self):
"""Builds the image models for train and eval."""
# Determine if we should build the train and eval model. When using
# distributed training we only want to build one or the other and not both.
with tf.variable_scope('model', use_resource=False):
m = CifarModel(self.hparams)
m.build('train')
self._num_trainable_params = m.num_trainable_params
self._saver = m.saver
with tf.variable_scope('model', reuse=True, use_resource=False):
meval = CifarModel(self.hparams)
meval.build('eval')
return m, meval
def _calc_starting_epoch(self, m):
"""Calculates the starting epoch for model m based on global step."""
hparams = self.hparams
batch_size = hparams.batch_size
steps_per_epoch = int(hparams.train_size / batch_size)
with self._new_session(m):
curr_step = self.session.run(m.global_step)
total_steps = steps_per_epoch * hparams.num_epochs
epochs_left = (total_steps - curr_step) // steps_per_epoch
starting_epoch = hparams.num_epochs - epochs_left
return starting_epoch
def _run_training_loop(self, m, curr_epoch):
"""Trains the cifar model `m` for one epoch."""
start_time = time.time()
while True:
try:
with self._new_session(m):
train_accuracy = helper_utils.run_epoch_training(
self.session, m, self.data_loader, curr_epoch)
tf.logging.info('Saving model after epoch')
self.save_model(step=curr_epoch)
break
except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
tf.logging.info('Retryable error caught: %s. Retrying.', e)
tf.logging.info('Finished epoch: {}'.format(curr_epoch))
tf.logging.info('Epoch time(min): {}'.format(
(time.time() - start_time) / 60.0))
return train_accuracy
def _compute_final_accuracies(self, meval):
"""Run once training is finished to compute final val/test accuracies."""
valid_accuracy = self.eval_child_model(meval, self.data_loader, 'val')
if self.hparams.eval_test:
test_accuracy = self.eval_child_model(meval, self.data_loader, 'test')
else:
test_accuracy = 0
tf.logging.info('Test Accuracy: {}'.format(test_accuracy))
return valid_accuracy, test_accuracy
def run_model(self):
"""Trains and evalutes the image model."""
hparams = self.hparams
# Build the child graph
with tf.Graph().as_default(), tf.device(
'/cpu:0' if FLAGS.use_cpu else '/gpu:0'):
m, meval = self._build_models()
# Figure out what epoch we are on
starting_epoch = self._calc_starting_epoch(m)
# Run the validation error right at the beginning
valid_accuracy = self.eval_child_model(
meval, self.data_loader, 'val')
tf.logging.info('Before Training Epoch: {} Val Acc: {}'.format(
starting_epoch, valid_accuracy))
training_accuracy = None
for curr_epoch in xrange(starting_epoch, hparams.num_epochs):
# Run one training epoch
training_accuracy = self._run_training_loop(m, curr_epoch)
valid_accuracy = self.eval_child_model(
meval, self.data_loader, 'val')
tf.logging.info('Epoch: {} Valid Acc: {}'.format(
curr_epoch, valid_accuracy))
valid_accuracy, test_accuracy = self._compute_final_accuracies(
meval)
tf.logging.info(
'Train Acc: {} Valid Acc: {} Test Acc: {}'.format(
training_accuracy, valid_accuracy, test_accuracy))
@property
def saver(self):
return self._saver
@property
def session(self):
return self._session
@property
def num_trainable_params(self):
return self._num_trainable_params
def main(_):
if FLAGS.dataset not in ['cifar10', 'cifar100']:
raise ValueError('Invalid dataset: %s' % FLAGS.dataset)
hparams = tf.contrib.training.HParams(
train_size=50000,
validation_size=0,
eval_test=1,
dataset=FLAGS.dataset,
data_path=FLAGS.data_path,
batch_size=128,
gradient_clipping_by_global_norm=5.0)
if FLAGS.model_name == 'wrn':
hparams.add_hparam('model_name', 'wrn')
hparams.add_hparam('num_epochs', 200)
hparams.add_hparam('wrn_size', 160)
hparams.add_hparam('lr', 0.1)
hparams.add_hparam('weight_decay_rate', 5e-4)
elif FLAGS.model_name == 'shake_shake_32':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 2)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'shake_shake_96':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 6)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'shake_shake_112':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 7)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'pyramid_net':
hparams.add_hparam('model_name', 'pyramid_net')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('lr', 0.05)
hparams.add_hparam('weight_decay_rate', 5e-5)
hparams.batch_size = 64
else:
raise ValueError('Not Valid Model Name: %s' % FLAGS.model_name)
cifar_trainer = CifarModelTrainer(hparams)
cifar_trainer.run_model()
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Wide-ResNet Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import custom_ops as ops
import numpy as np
import tensorflow as tf
def residual_block(
x, in_filter, out_filter, stride, activate_before_residual=False):
"""Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv.
Args:
x: Tensor that is the output of the previous layer in the model.
in_filter: Number of filters `x` has.
out_filter: Number of filters that the output of this layer will have.
stride: Integer that specified what stride should be applied to `x`.
activate_before_residual: Boolean on whether a BN->ReLU should be applied
to x before the convolution is applied.
Returns:
A Tensor that is the result of applying two sequences of BN->ReLU->3x3 Conv
and then adding that Tensor to `x`.
"""
if activate_before_residual: # Pass up RELU and BN activation for resnet
with tf.variable_scope('shared_activation'):
x = ops.batch_norm(x, scope='init_bn')
x = tf.nn.relu(x)
orig_x = x
else:
orig_x = x
block_x = x
if not activate_before_residual:
with tf.variable_scope('residual_only_activation'):
block_x = ops.batch_norm(block_x, scope='init_bn')
block_x = tf.nn.relu(block_x)
with tf.variable_scope('sub1'):
block_x = ops.conv2d(
block_x, out_filter, 3, stride=stride, scope='conv1')
with tf.variable_scope('sub2'):
block_x = ops.batch_norm(block_x, scope='bn2')
block_x = tf.nn.relu(block_x)
block_x = ops.conv2d(
block_x, out_filter, 3, stride=1, scope='conv2')
with tf.variable_scope(
'sub_add'): # If number of filters do not agree then zero pad them
if in_filter != out_filter:
orig_x = ops.avg_pool(orig_x, stride, stride)
orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
x = orig_x + block_x
return x
def _res_add(in_filter, out_filter, stride, x, orig_x):
"""Adds `x` with `orig_x`, both of which are layers in the model.
Args:
in_filter: Number of filters in `orig_x`.
out_filter: Number of filters in `x`.
stride: Integer specifying the stide that should be applied `orig_x`.
x: Tensor that is the output of the previous layer.
orig_x: Tensor that is the output of an earlier layer in the network.
Returns:
A Tensor that is the result of `x` and `orig_x` being added after
zero padding and striding are applied to `orig_x` to get the shapes
to match.
"""
if in_filter != out_filter:
orig_x = ops.avg_pool(orig_x, stride, stride)
orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
x = x + orig_x
orig_x = x
return x, orig_x
def build_wrn_model(images, num_classes, wrn_size):
"""Builds the WRN model.
Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.
Args:
images: Tensor of images that will be fed into the Wide ResNet Model.
num_classes: Number of classed that the model needs to predict.
wrn_size: Parameter that scales the number of filters in the Wide ResNet
model.
Returns:
The logits of the Wide ResNet model.
"""
kernel_size = wrn_size
filter_size = 3
num_blocks_per_resnet = 4
filters = [
min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
]
strides = [1, 2, 2] # stride for each resblock
# Run the first conv
with tf.variable_scope('init'):
x = images
output_filters = filters[0]
x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')
first_x = x # Res from the beginning
orig_x = x # Res from previous block
for block_num in range(1, 4):
with tf.variable_scope('unit_{}_0'.format(block_num)):
activate_before_residual = True if block_num == 1 else False
x = residual_block(
x,
filters[block_num - 1],
filters[block_num],
strides[block_num - 1],
activate_before_residual=activate_before_residual)
for i in range(1, num_blocks_per_resnet):
with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
x = residual_block(
x,
filters[block_num],
filters[block_num],
1,
activate_before_residual=False)
x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
strides[block_num - 1], x, orig_x)
final_stride_val = np.prod(strides)
x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)
with tf.variable_scope('unit_last'):
x = ops.batch_norm(x, scope='final_bn')
x = tf.nn.relu(x)
x = ops.global_avg_pool(x)
logits = ops.fc(x, num_classes)
return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment