Commit 46ef4860 authored by Neal Wu, committed by GitHub
Browse files

Merge pull request #1822 from alexlee-gk/master

Fix CDNA transformation bug and speed up its implementation.
parents afabda6d 90f63a1e
......@@ -261,6 +261,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
List of images transformed by the predicted CDNA kernels.
"""
batch_size = int(cdna_input.get_shape()[0])
height = int(prev_image.get_shape()[1])
width = int(prev_image.get_shape()[2])
# Predict kernels using linear function of last hidden layer.
cdna_kerns = slim.layers.fully_connected(
......@@ -276,20 +278,22 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
cdna_kerns /= norm_factor
cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)
# Treat the color channel dimension as the batch dimension since the same
# transformation is applied to each color channel.
# Treat the batch dimension as the channel dimension so that
# depthwise_conv2d can apply a different transformation to each sample.
cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
# Swap the batch and channel dimensions.
prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
# Transform image.
transformed = []
for kernel, preimg in zip(cdna_kerns, prev_images):
kernel = tf.squeeze(kernel)
if len(kernel.get_shape()) == 3:
kernel = tf.expand_dims(kernel, -1)
transformed.append(
tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
transformed = tf.concat(axis=0, values=transformed)
transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
# Transpose the dimensions to where they belong.
transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
transformed = tf.unstack(transformed, axis=-1)
return transformed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment