Commit 6406a152 authored by Martin Wicke

Merge pull request #35 from daviddao/master

Adding Spatial Transformer to Models
parents 30004531 41c52d60
@@ -7,3 +7,4 @@
 # The email address is not required for organizations.
 Google Inc.
+David Dao <daviddao@broad.mit.edu>
# Spatial Transformer Network
The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>
### API
A Spatial Transformer Network implemented in TensorFlow 0.7, based on [2].
#### How to use
<div align="center">
<img src="http://i.imgur.com/gfqLV3f.png"><br><br>
</div>
```python
transformer(U, theta, downsample_factor=1)
```
#### Parameters
- `U` : float. The output of a convolutional net, with shape `[num_batch, height, width, num_channels]`.
- `theta` : float. The output of the localisation network, with shape `[num_batch, 6]`.
- `downsample_factor` : float. A value of 1 keeps the original size of the image; values larger than 1 downsample the image, and values below 1 upsample it. For example, an input image with height = 100 and width = 200 and `downsample_factor = 2` gives a 50 x 100 output. A minimal end-to-end call is sketched below.
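Putting this together, a minimal sketch (the 40x40 shape and the identity `theta` below are illustrative choices, not requirements of the API):

```python
import numpy as np
import tensorflow as tf
from spatial_transformer import transformer

# Two 40x40 single-channel images: [num_batch, height, width, num_channels]
U = tf.placeholder(tf.float32, [2, 40, 40, 1])
# One affine transform per image, flattened to [num_batch, 6];
# the identity transform returns each image unchanged.
identity = np.array([[1., 0., 0.], [0., 1., 0.]], dtype='float32').flatten()
theta = tf.constant(np.tile(identity, (2, 1)))
# downsample_factor=1 keeps the output at the input's 40x40 size
output = transformer(U, theta, downsample_factor=1)
```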
#### Notes
To initialize the network to the identity transform, initialize ``theta`` to:
```python
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
```
#### Experiments
<div align="center">
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>
We used cluttered MNIST. The left column shows the input images; the right column shows the parts of the image attended to by the STN.
All experiments were run in TensorFlow 0.7.
### References
[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
import numpy as np
from tf_utils import weight_variable, bias_variable, dense_to_one_hot
# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']
# %% Turn the dense labels into a one-hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)
# %% Graph representation of our network
# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])
# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph. If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant. Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])
# %% We'll set up the two-layer localisation network to figure out the
# parameters for an affine transformation of the input
# %% Create variables for fully connected layer
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])
W_fc_loc2 = weight_variable([20, 6])
initial = np.array([[1., 0., 0.], [0., 1., 0.]])  # Use identity transformation as starting point
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
# %% Define the two layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularization, to reduce overfitting, like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
# %% We'll setup the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])
# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels
# instead of pooling, we use strides of 2 and more layers
# with smaller filters.
h_conv1 = tf.nn.relu(
    tf.nn.conv2d(input=h_trans,
                 filter=W_conv1,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv1)
# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
    tf.nn.conv2d(input=h_conv1,
                 filter=W_conv2,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv2)
# %% We'll now reshape so we can connect to a fully-connected layer:
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred + 1e-10))  # small epsilon guards against log(0)
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
# %% We now create a new session to actually initialize the variables:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
# %% We'll now train in minibatches and report accuracy, loss:
iter_per_epoch = 100
n_epochs = 500
train_size = 10000
indices = np.linspace(0, train_size - 1, iter_per_epoch)
indices = indices.astype('int')
for epoch_i in range(n_epochs):
    for iter_i in range(iter_per_epoch - 1):
        batch_xs = X_train[indices[iter_i]:indices[iter_i + 1]]
        batch_ys = Y_train[indices[iter_i]:indices[iter_i + 1]]

        if iter_i % 10 == 0:
            loss = sess.run(cross_entropy,
                            feed_dict={
                                x: batch_xs,
                                y: batch_ys,
                                keep_prob: 1.0
                            })
            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))

        sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})

    print('Accuracy: ' + str(sess.run(accuracy,
                                      feed_dict={
                                          x: X_valid,
                                          y: Y_valid,
                                          keep_prob: 1.0
                                      })))

    # theta = sess.run(h_fc_loc2, feed_dict={
    #     x: batch_xs, keep_prob: 1.0})
    # print(theta[0])
### How to get the data
#### Cluttered MNIST
The cluttered MNIST dataset can be downloaded from [1] or generated via [2].
Settings used for `cluttered_mnist.py`:
```python
ORG_SHP = [28, 28]
OUT_SHP = [40, 40]
NUM_DISTORTIONS = 8
dist_size = (5, 5)
```
[1] https://github.com/daviddao/spatial-transformer-tensorflow
[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')
# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')

    # %% Zoom into the image: since W_fc1 is all zeros, h_fc1 below always
    # equals b_fc1, i.e. a fixed 0.5x scaling transform.
    initial = np.array([[0.5, 0., 0.], [0., 0.5, 0.]])
    initial = initial.astype('float32')
    initial = initial.flatten()
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')

    h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)
# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})
# plt.imshow(y[0])
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
def transformer(U, theta, downsample_factor=1, name='SpatialTransformer',
                **kwargs):
    """Spatial Transformer Layer

    Implements a spatial transformer layer as described in [1]_.
    Based on [2]_ and edited by David Dao for TensorFlow.

    Parameters
    ----------
    U : float
        The output of a convolutional net, with shape
        [num_batch, height, width, num_channels].
    theta : float
        The output of the localisation network, with shape [num_batch, 6].
    downsample_factor : float
        A value of 1 keeps the original size of the image.
        Values larger than 1 downsample the image; values below 1 upsample it.
        For example, an input image with height = 100 and width = 200 and
        downsample_factor = 2 gives a 50 x 100 output.

    References
    ----------
    .. [1] Spatial Transformer Networks
           Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
           Submitted on 5 Jun 2015
    .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

    Notes
    -----
    To initialize the network to the identity transform, initialize
    ``theta`` to:

        identity = np.array([[1., 0., 0.],
                             [0., 1., 0.]])
        identity = identity.flatten()
        theta = tf.Variable(initial_value=identity)
    """
    def _repeat(x, n_repeats):
        with tf.variable_scope('_repeat'):
            # Tile every element of x n_repeats times by multiplying the
            # column vector x with a [1, n_repeats] row of ones.
            rep = tf.transpose(
                tf.expand_dims(tf.ones(shape=tf.pack([n_repeats, ])), 1), [1, 0])
            rep = tf.cast(rep, 'int32')
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])
    def _interpolate(im, x, y, downsample_factor):
        with tf.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = tf.cast(height_f // downsample_factor, 'int32')
            out_width = tf.cast(width_f // downsample_factor, 'int32')
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')

            # scale indices from [-1, 1] to [0, width/height]
            x = (x + 1.0) * (width_f) / 2.0
            y = (y + 1.0) * (height_f) / 2.0

            # do sampling
            x0 = tf.cast(tf.floor(x), 'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y), 'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # the channels dim
            im_flat = tf.reshape(im, tf.pack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
            wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
            wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
            wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
            output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
            return output
    def _meshgrid(height, width):
        with tf.variable_scope('_meshgrid'):
            # This should be equivalent to:
            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
            #                         np.linspace(-1, 1, height))
            #  ones = np.ones(np.prod(x_t.shape))
            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
            x_t = tf.matmul(
                tf.ones(shape=tf.pack([height, 1])),
                tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
            y_t = tf.matmul(
                tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
                tf.ones(shape=tf.pack([1, width])))

            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            ones = tf.ones_like(x_t_flat)
            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
            return grid
    def _transform(theta, input_dim, downsample_factor):
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = tf.cast(height_f // downsample_factor, 'int32')
            out_width = tf.cast(width_f // downsample_factor, 'int32')
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.pack([num_batch]))
            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))

            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            T_g = tf.batch_matmul(theta, grid)
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                downsample_factor)

            output = tf.reshape(
                input_transformed,
                tf.pack([num_batch, out_height, out_width, num_channels]))
            return output
    with tf.variable_scope(name):
        output = _transform(theta, U, downsample_factor)
        return output
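# %% Minimal smoke test (an illustrative addition, not part of the original
# layer): run a random batch through an identity transform and check that
# the output shape matches the input shape.
if __name__ == '__main__':
    import numpy as np
    U = tf.placeholder(tf.float32, [2, 40, 40, 1])
    identity = np.array([[1., 0., 0.], [0., 1., 0.]], dtype='float32').flatten()
    theta = tf.constant(np.tile(identity, (2, 1)))
    h = transformer(U, theta, downsample_factor=1)
    with tf.Session() as sess:
        out = sess.run(h, feed_dict={
            U: np.random.rand(2, 40, 40, 1).astype('float32')})
    print(out.shape)  # expected: (2, 40, 40, 1)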
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
import tensorflow as tf
import numpy as np
def conv2d(x, n_filters,
           k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02,
           activation=lambda x: x,
           bias=True,
           padding='SAME',
           name="Conv2D"):
    """2D Convolution with options for kernel size, stride, and init deviation.

    Parameters
    ----------
    x : Tensor
        Input tensor to convolve.
    n_filters : int
        Number of filters to apply.
    k_h : int, optional
        Kernel height.
    k_w : int, optional
        Kernel width.
    stride_h : int, optional
        Stride in rows.
    stride_w : int, optional
        Stride in cols.
    stddev : float, optional
        Initialization's standard deviation.
    activation : arguments, optional
        Function which applies a nonlinearity.
    bias : bool, optional
        Whether to add a bias term.
    padding : str, optional
        'SAME' or 'VALID'
    name : str, optional
        Variable scope to use.

    Returns
    -------
    x : Tensor
        Convolved input.
    """
    with tf.variable_scope(name):
        w = tf.get_variable(
            'w', [k_h, k_w, x.get_shape()[-1], n_filters],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(
            x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
        if bias:
            b = tf.get_variable(
                'b', [n_filters],
                initializer=tf.truncated_normal_initializer(stddev=stddev))
            conv = conv + b
        return activation(conv)  # defaults to the identity
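# %% Usage sketch (illustrative names, not used elsewhere in this repo):
#   x_img = tf.placeholder(tf.float32, [None, 40, 40, 1])
#   h = conv2d(x_img, n_filters=16, k_h=3, k_w=3, name='conv_example')
# With the default stride of 2 and 'SAME' padding, h has shape
# [None, 20, 20, 16].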
def linear(x, n_units, scope=None, stddev=0.02,
           activation=lambda x: x):
    """Fully-connected network.

    Parameters
    ----------
    x : Tensor
        Input tensor to the network.
    n_units : int
        Number of units to connect to.
    scope : str, optional
        Variable scope to use.
    stddev : float, optional
        Initialization's standard deviation.
    activation : arguments, optional
        Function which applies a nonlinearity.

    Returns
    -------
    x : Tensor
        Fully-connected output.
    """
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        return activation(tf.matmul(x, matrix))
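# %% Usage sketch (illustrative names): project flattened 1600-dim inputs
# down to 10 outputs:
#   logits = linear(x_flat, 10, scope='logits_example')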
# %%
def weight_variable(shape):
    '''Helper function to create a weight variable.

    Weights are initialized to zero here; a random normal initializer
    is left commented out as an alternative.

    Parameters
    ----------
    shape : list
        Size of weight variable
    '''
    # initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
    initial = tf.zeros(shape)
    return tf.Variable(initial)
# %%
def bias_variable(shape):
    '''Helper function to create a bias variable initialized with
    small random normal values.

    Parameters
    ----------
    shape : list
        Size of bias variable
    '''
    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
    return tf.Variable(initial)
# %%
def dense_to_one_hot(labels, n_classes=2):
    """Convert class labels from scalars to one-hot vectors."""
    labels = np.array(labels)
    n_labels = labels.shape[0]
    index_offset = np.arange(n_labels) * n_classes
    labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
    labels_one_hot.flat[index_offset + labels.ravel()] = 1
    return labels_one_hot
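# %% Example: dense_to_one_hot(np.array([1, 0]), n_classes=3) returns
#   [[0., 1., 0.],
#    [1., 0., 0.]]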