Commit 90f63a1e authored by Alex Lee

Fix CDNA transformation bug and speed up its implementation.

- Fix a CDNA transformation bug where the transformed color and mask channels were combined incorrectly.
- Remove the for loop over the batch size in the implementation of the CDNA transformation. This speeds up graph construction.
parent 44fa1d37
@@ -261,6 +261,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
     List of images transformed by the predicted CDNA kernels.
   """
   batch_size = int(cdna_input.get_shape()[0])
+  height = int(prev_image.get_shape()[1])
+  width = int(prev_image.get_shape()[2])
 
   # Predict kernels using linear function of last hidden layer.
   cdna_kerns = slim.layers.fully_connected(
@@ -276,20 +278,22 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
   norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
   cdna_kerns /= norm_factor
 
-  cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
-  cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
-  prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)
+  # Treat the color channel dimension as the batch dimension since the same
+  # transformation is applied to each color channel.
+  # Treat the batch dimension as the channel dimension so that
+  # depthwise_conv2d can apply a different transformation to each sample.
+  cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+  cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+  # Swap the batch and channel dimensions.
+  prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
 
   # Transform image.
-  transformed = []
-  for kernel, preimg in zip(cdna_kerns, prev_images):
-    kernel = tf.squeeze(kernel)
-    if len(kernel.get_shape()) == 3:
-      kernel = tf.expand_dims(kernel, -1)
-    transformed.append(
-        tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
-  transformed = tf.concat(axis=0, values=transformed)
-  transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
+  transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+  # Transpose the dimensions to where they belong.
+  transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+  transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+  transformed = tf.unstack(transformed, axis=-1)
   return transformed