"vscode:/vscode.git/clone" did not exist on "5e56e56399224abd135e4ac2f2e5d25db0f2a877"
Commit bf60abf8 authored by Timur's avatar Timur Committed by Martin Wicke
Browse files

Spatial transformer: (#57)

* Modified the way the output size is specified.
* Added support for batches of inputs.
parent 8332400d
......@@ -14,7 +14,7 @@
# ==============================================================================
import tensorflow as tf
def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
"""Spatial Transformer Layer
Implements a spatial transformer layer as described in [1]_.
......@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
theta: float
The output of the
localisation network should be [num_batch, 6].
downsample_factor : float
A value of 1 will keep the original size of the image
Values larger than 1 will downsample the image.
Values below 1 will upsample the image
example image: height = 100, width = 200
downsample_factor = 2
output image will then be 50, 100
out_size: tuple of two floats
The size of the output of the network
References
----------
.. [1] Spatial Transformer Networks
......@@ -61,7 +56,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
return tf.reshape(x,[-1])
def _interpolate(im, x, y, downsample_factor):
def _interpolate(im, x, y, out_size):
with tf.variable_scope('_interpolate'):
# constants
num_batch = tf.shape(im)[0]
......@@ -73,8 +68,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
y = tf.cast(y, 'float32')
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = tf.cast(height_f // downsample_factor, 'int32')
out_width = tf.cast(width_f // downsample_factor, 'int32')
out_height = out_size[0]
out_width = out_size[1]
zero = tf.zeros([], dtype='int32')
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
......@@ -142,7 +137,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
return grid
def _transform(theta, input_dim, downsample_factor):
def _transform(theta, input_dim, out_size):
with tf.variable_scope('_transform'):
num_batch = tf.shape(input_dim)[0]
height = tf.shape(input_dim)[1]
......@@ -154,8 +149,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
# grid of (x_t, y_t, 1), eq (1) in ref [1]
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = tf.cast(height_f // downsample_factor, 'int32')
out_width = tf.cast(width_f // downsample_factor, 'int32')
out_height = out_size[0]
out_width = out_size[1]
grid = _meshgrid(out_height, out_width)
grid = tf.expand_dims(grid,0)
grid = tf.reshape(grid,[-1])
......@@ -171,11 +166,34 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
input_transformed = _interpolate(
input_dim, x_s_flat, y_s_flat,
downsample_factor)
out_size)
output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
return output
with tf.variable_scope(name):
output = _transform(theta, U, downsample_factor)
return output
\ No newline at end of file
output = _transform(theta, U, out_size)
return output
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
"""Batch Spatial Transformer Layer
Parameters
----------
U : float
tensor of inputs [num_batch,height,width,num_channels]
thetas : float
a set of transformations for each input [num_batch,num_transforms,6]
out_size : int
the size of the output [out_height,out_width]
Returns: float
Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
"""
with tf.variable_scope(name):
num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
indices = [[i]*num_transforms for i in xrange(num_batch)]
input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
return transformer(input_repeated, thetas, out_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment