Commit eea839c4 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add multi-layer conditional mask prediction head.

PiperOrigin-RevId: 410930769
parent 10c0e96b
...@@ -151,7 +151,7 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None): ...@@ -151,7 +151,7 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Mask size must be set.') raise ValueError('Mask size must be set.')
return FullyConnectedMaskHead(num_init_channels, mask_size) return FullyConnectedMaskHead(num_init_channels, mask_size)
elif name == 'embedding_projection': elif _is_mask_head_param_free(name):
return tf.keras.layers.Lambda(lambda x: x) return tf.keras.layers.Lambda(lambda x: x)
elif name.startswith('resnet'): elif name.startswith('resnet'):
...@@ -395,6 +395,89 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2): ...@@ -395,6 +395,89 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2):
return tf.transpose(same_mask_prob, (0, 3, 1, 2)) return tf.transpose(same_mask_prob, (0, 3, 1, 2))
def _per_pixel_single_conv(input_tensor, params, channels):
  """Applies one per-pixel convolution with instance-specific parameters.

  Every instance carries its own weights and biases, packed back to back in
  its row of `params`: the weights come first, the biases after them.

  Args:
    input_tensor: A [num_instances, height, width, input_channels] shaped
      float tensor.
    params: A [num_instances, num_params] shaped float tensor. The first
      input_channels * channels entries of each row are reshaped into an
      [input_channels, channels] weight matrix; all remaining entries are
      used as the bias.
    channels: int, number of output channels of the convolution.

  Returns:
    output: A float tensor of shape [num_instances, height, width, channels]
  """
  in_channels = input_tensor.get_shape().as_list()[3]
  num_weights = in_channels * channels
  instance_count = tf.shape(params)[0]

  bias = params[:, num_weights:]
  kernel = tf.reshape(params[:, :num_weights],
                      (instance_count, in_channels, channels))

  # A singleton axis is inserted into the input and two into the kernel so
  # that the batched matmul applies each instance's kernel at every pixel.
  expanded_input = input_tensor[:, :, tf.newaxis, :]
  expanded_kernel = kernel[:, tf.newaxis, tf.newaxis, :, :]
  projected = expanded_input @ expanded_kernel

  # Drop the singleton axis that was inserted for the matmul.
  projected = projected[:, :, 0, :, :]
  return projected + bias[:, tf.newaxis, tf.newaxis, :]
def per_pixel_conditional_conv(input_tensor, parameters, channels, depth):
  """Uses parameters to perform per-pixel convolutions with the given depth [1].

  The flat parameter vector of every instance is consumed layer by layer;
  within a layer the weights come first and the biases second. Every layer
  except the last produces `channels` outputs, the last layer always
  produces a single channel.

  [1]: https://arxiv.org/abs/2003.05664

  Args:
    input_tensor: float tensor of shape [num_instances, height,
      width, input_channels]
    parameters: A [num_instances, num_params] float tensor. If num_params
      is incompatible with the given channels and depth, a ValueError will
      be raised.
    channels: int, the number of channels in the convolution.
    depth: int, the number of layers of convolutions to perform. Must be at
      least 1.

  Returns:
    output: A [num_instances, height, width, 1] tensor with the conditional
      conv applied according to each instance's parameters.

  Raises:
    ValueError: If depth is smaller than 1, if depth is 1 and input_channels
      differs from channels, or if num_params does not match the number of
      parameters required by the given channels and depth.
  """
  # Without this check depth=0 could pass the parameter-count test below and
  # silently hand the input back unchanged.
  if depth < 1:
    raise ValueError('depth must be at least 1, but got {}'.format(depth))

  input_channels = input_tensor.get_shape().as_list()[3]
  num_params = parameters.get_shape().as_list()[1]

  if depth == 1 and input_channels != channels:
    raise ValueError(
        'When depth=1, input_channels({}) should be equal to'.format(
            input_channels) + ' channels({})'.format(channels))

  # Output width of every layer; the final layer always has one channel.
  layer_widths = [channels] * (depth - 1) + [1]

  # Derive the expected parameter count from the same layer plan that the
  # loop below consumes, so that the two can never disagree.
  expected_params = 0
  fan_in = input_channels
  for fan_out in layer_widths:
    expected_params += fan_in * fan_out + fan_out  # weights + biases
    fan_in = fan_out

  if num_params != expected_params:
    raise ValueError('Expected {} parameters at depth {}, but got {}'.format(
        expected_params, depth, num_params))

  output = input_tensor
  start = 0
  fan_in = input_channels
  for fan_out in layer_widths:
    layer_size = fan_in * fan_out + fan_out
    layer_params = parameters[:, start:start + layer_size]
    output = _per_pixel_single_conv(output, layer_params, fan_out)
    start += layer_size
    fan_in = fan_out

  return output
class ResNetMaskNetwork(tf.keras.layers.Layer): class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks.""" """A small wrapper around ResNet blocks to predict masks."""
...@@ -560,6 +643,16 @@ class DenseResNet(tf.keras.layers.Layer): ...@@ -560,6 +643,16 @@ class DenseResNet(tf.keras.layers.Layer):
return self.out_conv(self.resnet(net)) return self.out_conv(self.resnet(net))
def _is_mask_head_param_free(name):
# Mask heads which don't have parameters of their own and instead rely
# on the instance embedding.
if name == 'embedding_projection' or name.startswith('cond_inst'):
return True
return False
class MaskHeadNetwork(tf.keras.layers.Layer): class MaskHeadNetwork(tf.keras.layers.Layer):
"""Mask head class for DeepMAC.""" """Mask head class for DeepMAC."""
...@@ -586,13 +679,14 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -586,13 +679,14 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
self._use_instance_embedding = use_instance_embedding self._use_instance_embedding = use_instance_embedding
self._network_type = network_type self._network_type = network_type
self._num_init_channels = num_init_channels
if (self._use_instance_embedding and if (self._use_instance_embedding and
(self._network_type == 'embedding_projection')): (_is_mask_head_param_free(network_type))):
raise ValueError(('Cannot feed instance embedding to mask head when ' raise ValueError(('Cannot feed instance embedding to mask head when '
'computing embedding projection.')) 'mask-head has no parameters.'))
if network_type == 'embedding_projection': if _is_mask_head_param_free(network_type):
self.project_out = tf.keras.layers.Lambda(lambda x: x) self.project_out = tf.keras.layers.Lambda(lambda x: x)
else: else:
self.project_out = tf.keras.layers.Conv2D( self.project_out = tf.keras.layers.Conv2D(
...@@ -632,6 +726,11 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -632,6 +726,11 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :] instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
out = embedding_projection(instance_embedding, out) out = embedding_projection(instance_embedding, out)
elif self._network_type.startswith('cond_inst'):
depth = int(self._network_type.lstrip('cond_inst'))
out = per_pixel_conditional_conv(out, instance_embedding,
self._num_init_channels, depth)
if out.shape[-1] > 1: if out.shape[-1] > 1:
out = self.project_out(out) out = self.project_out(out)
......
...@@ -280,17 +280,60 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -280,17 +280,60 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :]) self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2]) self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2])
def test_per_pixel_single_conv_multiple_instance(self):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') inp = tf.zeros((5, 32, 32, 7))
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase): params = tf.zeros((5, 7*8 + 8))
@parameterized.parameters( out = deepmac_meta_arch._per_pixel_single_conv(inp, params, 8)
['hourglass10', 'hourglass20', 'resnet4']) self.assertEqual(out.shape, (5, 32, 32, 8))
def test_mask_network(self, head_type):
net = deepmac_meta_arch.MaskHeadNetwork(head_type, 8)
out = net(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True) def test_per_pixel_conditional_conv_error(self):
self.assertEqual(out.shape, (2, 32, 32))
with self.assertRaises(ValueError):
deepmac_meta_arch.per_pixel_conditional_conv(
tf.zeros((10, 32, 32, 8)), tf.zeros((10, 2)), 8, 3)
def test_per_pixel_conditional_conv_error_tf_func(self):
  """Checks that a bad parameter count raises when wrapped in tf.function."""
  features = tf.zeros((10, 32, 32, 8))
  too_few_params = tf.zeros((10, 2))

  with self.assertRaises(ValueError):
    graph_fn = tf.function(deepmac_meta_arch.per_pixel_conditional_conv)
    graph_fn(features, too_few_params, 8, 3)
def test_per_pixel_conditional_conv_depth1_error(self):
  """Checks that depth=1 with 7 input channels and 99 channels raises."""
  features = tf.zeros((10, 32, 32, 7))
  params = tf.zeros((10, 8))

  with self.assertRaises(ValueError):
    _ = deepmac_meta_arch.per_pixel_conditional_conv(
        features, params, 99, 1)
def test_per_pixel_conditional_conv_depth1(self):
  """Checks the output shape of a single-layer conditional conv."""
  features = tf.zeros((10, 32, 32, 7))
  # 7 weights + 1 bias for the single one-channel layer.
  params = tf.zeros((10, 8))

  out = deepmac_meta_arch.per_pixel_conditional_conv(features, params, 7, 1)

  self.assertEqual(out.shape, (10, 32, 32, 1))
def test_per_pixel_conditional_conv_depth2(self):
  """Checks the output shape of a two-layer conditional conv."""
  first_layer_params = 7 * 9 + 9  # weights + biases of layer 1
  second_layer_params = 9 + 1  # weights + bias of layer 2
  features = tf.zeros((10, 32, 32, 7))
  params = tf.zeros((10, first_layer_params + second_layer_params))

  out = deepmac_meta_arch.per_pixel_conditional_conv(features, params, 9, 2)

  self.assertEqual(out.shape, (10, 32, 32, 1))
def test_per_pixel_conditional_conv_depth3(self):
  """Checks the output shape of a three-layer conditional conv."""
  # From the paper https://arxiv.org/abs/2003.05664
  # 169 = (10 * 8 + 8) + (8 * 8 + 8) + (8 + 1) parameters per instance.
  features = tf.zeros((10, 32, 32, 10))
  params = tf.zeros((10, 169))

  out = deepmac_meta_arch.per_pixel_conditional_conv(features, params, 8, 3)

  self.assertEqual(out.shape, (10, 32, 32, 1))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
def test_mask_network_params_resnet4(self): def test_mask_network_params_resnet4(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet4', num_init_channels=8) net = deepmac_meta_arch.MaskHeadNetwork('resnet4', num_init_channels=8)
...@@ -301,39 +344,65 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase): ...@@ -301,39 +344,65 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(trainable_params.numpy(), 8665) self.assertEqual(trainable_params.numpy(), 8665)
def test_mask_network_resnet_tf_function(self): def test_mask_network_embedding_projection_small(self):
net = deepmac_meta_arch.MaskHeadNetwork('resnet8')
call_func = tf.function(net.__call__)
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_embedding_projection_zero(self):
net = deepmac_meta_arch.MaskHeadNetwork( net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_projection', num_init_channels=8, 'embedding_projection', num_init_channels=-1,
use_instance_embedding=False) use_instance_embedding=False)
call_func = tf.function(net.__call__) call_func = tf.function(net.__call__)
out = call_func(tf.zeros((2, 7)), tf.zeros((2, 32, 32, 7)), training=True) out = call_func(1e6 + tf.zeros((2, 7)),
tf.zeros((2, 32, 32, 7)), training=True)
self.assertEqual(out.shape, (2, 32, 32)) self.assertEqual(out.shape, (2, 32, 32))
self.assertAllGreater(out.numpy(), -np.inf) self.assertAllGreater(out.numpy(), -np.inf)
self.assertAllLess(out.numpy(), np.inf) self.assertAllLess(out.numpy(), np.inf)
def test_mask_network_embedding_projection_small(self): @parameterized.parameters([
{'mask_net': 'resnet4', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'hourglass10', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'hourglass20', 'mask_net_channels': 8,
'instance_embedding_dim': 4, 'input_channels': 16,
'use_instance_embedding': False},
{'mask_net': 'cond_inst3', 'mask_net_channels': 8,
'instance_embedding_dim': 153, 'input_channels': 8,
'use_instance_embedding': False},
{'mask_net': 'cond_inst3', 'mask_net_channels': 8,
'instance_embedding_dim': 169, 'input_channels': 10,
'use_instance_embedding': False},
{'mask_net': 'cond_inst1', 'mask_net_channels': 8,
'instance_embedding_dim': 9, 'input_channels': 8,
'use_instance_embedding': False},
{'mask_net': 'cond_inst2', 'mask_net_channels': 8,
'instance_embedding_dim': 81, 'input_channels': 8,
'use_instance_embedding': False},
])
def test_mask_network(self, mask_net, mask_net_channels,
instance_embedding_dim, input_channels,
use_instance_embedding):
net = deepmac_meta_arch.MaskHeadNetwork( net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_projection', num_init_channels=-1, mask_net, num_init_channels=mask_net_channels,
use_instance_embedding=False) use_instance_embedding=use_instance_embedding)
call_func = tf.function(net.__call__) call_func = tf.function(net.__call__)
out = call_func(1e6 + tf.zeros((2, 7)), out = call_func(tf.zeros((2, instance_embedding_dim)),
tf.zeros((2, 32, 32, 7)), training=True) tf.zeros((2, 32, 32, input_channels)), training=True)
self.assertEqual(out.shape, (2, 32, 32)) self.assertEqual(out.shape, (2, 32, 32))
self.assertAllGreater(out.numpy(), -np.inf) self.assertAllGreater(out.numpy(), -np.inf)
self.assertAllLess(out.numpy(), np.inf) self.assertAllLess(out.numpy(), np.inf)
out = call_func(tf.zeros((2, instance_embedding_dim)),
tf.zeros((2, 32, 32, input_channels)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
out = call_func(tf.zeros((0, instance_embedding_dim)),
tf.zeros((0, 32, 32, input_channels)), training=True)
self.assertEqual(out.shape, (0, 32, 32))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment