Commit 65407126 authored by Vighnesh Birodkar, committed by TF Object Detection Team
Browse files

Fix conditional convs by adding ReLU.

PiperOrigin-RevId: 411887283
parent c280c4ee
......@@ -465,7 +465,8 @@ def per_pixel_conditional_conv(input_tensor, parameters, channels, depth):
output = input_tensor
for i in range(depth):
if i == (depth - 1):
is_last_layer = i == (depth - 1)
if is_last_layer:
channels = 1
num_params_single_conv = channels * input_channels + channels
......@@ -473,6 +474,10 @@ def per_pixel_conditional_conv(input_tensor, parameters, channels, depth):
start += num_params_single_conv
output = _per_pixel_single_conv(output, params, channels)
if not is_last_layer:
output = tf.nn.relu(output)
input_channels = channels
return output
......
......@@ -306,30 +306,95 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
_ = deepmac_meta_arch.per_pixel_conditional_conv(
tf.zeros((10, 32, 32, 7)), tf.zeros((10, 8)), 99, 1)
def test_per_pixel_conditional_conv_depth1(self):
@parameterized.parameters([
{
'num_input_channels': 7,
'instance_embedding_dim': 8,
'channels': 7,
'depth': 1
},
{
'num_input_channels': 7,
'instance_embedding_dim': 82,
'channels': 9,
'depth': 2
},
{ # From https://arxiv.org/abs/2003.05664
'num_input_channels': 10,
'instance_embedding_dim': 169,
'channels': 8,
'depth': 3
},
{
'num_input_channels': 8,
'instance_embedding_dim': 433,
'channels': 16,
'depth': 3
},
{
'num_input_channels': 8,
'instance_embedding_dim': 1377,
'channels': 32,
'depth': 3
},
{
'num_input_channels': 8,
'instance_embedding_dim': 4801,
'channels': 64,
'depth': 3
},
])
def test_per_pixel_conditional_conv_shape(
self, num_input_channels, instance_embedding_dim, channels, depth):
out = deepmac_meta_arch.per_pixel_conditional_conv(
tf.zeros((10, 32, 32, 7)), tf.zeros((10, 8)), 7, 1)
tf.zeros((10, 32, 32, num_input_channels)),
tf.zeros((10, instance_embedding_dim)), channels, depth)
self.assertEqual(out.shape, (10, 32, 32, 1))
def test_per_pixel_conditional_conv_depth2(self):
def test_per_pixel_conditional_conv_value_depth1(self):
num_params = (
7 * 9 + 9 + # layer 1
9 + 1) # layer 2
input_tensor = tf.constant(np.array([1, 2, 3]))
input_tensor = tf.reshape(input_tensor, (1, 1, 1, 3))
instance_embedding = tf.constant(
np.array([1, 10, 100, 1000]))
instance_embedding = tf.reshape(instance_embedding, (1, 4))
out = deepmac_meta_arch.per_pixel_conditional_conv(
tf.zeros((10, 32, 32, 7)), tf.zeros((10, num_params)), 9, 2)
input_tensor, instance_embedding, channels=3, depth=1)
self.assertEqual(out.shape, (10, 32, 32, 1))
expected_output = np.array([1321])
expected_output = np.reshape(expected_output, (1, 1, 1, 1))
self.assertAllClose(expected_output, out)
def test_per_pixel_conditional_conv_depth3(self):
def test_per_pixel_conditional_conv_value_depth2_single(self):
# From the paper https://arxiv.org/abs/2003.05664
input_tensor = tf.constant(np.array([2]))
input_tensor = tf.reshape(input_tensor, (1, 1, 1, 1))
instance_embedding = tf.constant(
np.array([-2, 3, 100, 5]))
instance_embedding = tf.reshape(instance_embedding, (1, 4))
out = deepmac_meta_arch.per_pixel_conditional_conv(
tf.zeros((10, 32, 32, 10)), tf.zeros((10, 169)), 8, 3)
input_tensor, instance_embedding, channels=1, depth=2)
self.assertEqual(out.shape, (10, 32, 32, 1))
expected_output = np.array([5])
expected_output = np.reshape(expected_output, (1, 1, 1, 1))
self.assertAllClose(expected_output, out)
def test_per_pixel_conditional_conv_value_depth2_identity(self):
  """Checks the numeric output of a depth-2 conditional conv on one pixel.

  With 2 input channels, channels=2 and depth=2, the 9 embedding values
  split into 2 * 2 + 2 = 6 parameters for the first layer and
  1 * 2 + 1 = 3 for the final single-channel layer.
  """
  # A single pixel with two input channels.
  pixel = tf.reshape(tf.constant(np.array([1, 2])), (1, 1, 1, 2))

  # NOTE(review): presumably each layer's weights precede its biases, which
  # would make layer 1 an identity matrix with biases (1, -3) and layer 2
  # weights (5, 100) with bias -9 -- confirm against _per_pixel_single_conv.
  # Under that reading: [1, 2] -> [2, -1] -> ReLU -> [2, 0] -> 5 * 2 - 9 = 1.
  embedding = tf.reshape(
      tf.constant(np.array([1, 0, 0, 1, 1, -3, 5, 100, -9])), (1, 9))

  result = deepmac_meta_arch.per_pixel_conditional_conv(
      pixel, embedding, channels=2, depth=2)

  expected = np.reshape(np.array([1]), (1, 1, 1, 1))
  self.assertAllClose(expected, result)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
......@@ -358,27 +423,55 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllLess(out.numpy(), np.inf)
@parameterized.parameters([
{'mask_net': 'resnet4', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'hourglass10', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'hourglass20', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'cond_inst3', 'mask_net_channels': 8,
'instance_embedding_dim': 153, 'input_channels': 8,
'use_instance_embedding': False},
{'mask_net': 'cond_inst3', 'mask_net_channels': 8,
'instance_embedding_dim': 169, 'input_channels': 10,
'use_instance_embedding': False},
{'mask_net': 'cond_inst1', 'mask_net_channels': 8,
'instance_embedding_dim': 9, 'input_channels': 8,
'use_instance_embedding': False},
{'mask_net': 'cond_inst2', 'mask_net_channels': 8,
'instance_embedding_dim': 81, 'input_channels': 8,
'use_instance_embedding': False},
{
'mask_net': 'resnet4',
'mask_net_channels': 8,
'instance_embedding_dim': 4,
'input_channels': 16,
'use_instance_embedding': False
},
{
'mask_net': 'hourglass10',
'mask_net_channels': 8,
'instance_embedding_dim': 4,
'input_channels': 16,
'use_instance_embedding': False
},
{
'mask_net': 'hourglass20',
'mask_net_channels': 8,
'instance_embedding_dim': 4,
'input_channels': 16,
'use_instance_embedding': False
},
{
'mask_net': 'cond_inst3',
'mask_net_channels': 8,
'instance_embedding_dim': 153,
'input_channels': 8,
'use_instance_embedding': False
},
{
'mask_net': 'cond_inst3',
'mask_net_channels': 8,
'instance_embedding_dim': 169,
'input_channels': 10,
'use_instance_embedding': False
},
{
'mask_net': 'cond_inst1',
'mask_net_channels': 8,
'instance_embedding_dim': 9,
'input_channels': 8,
'use_instance_embedding': False
},
{
'mask_net': 'cond_inst2',
'mask_net_channels': 8,
'instance_embedding_dim': 81,
'input_channels': 8,
'use_instance_embedding': False
},
])
def test_mask_network(self, mask_net, mask_net_channels,
instance_embedding_dim, input_channels,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment