Commit 9c28cff8 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Switch from Euclidean distance to a projection-based distance and add a 1D ResNet.

PiperOrigin-RevId: 395297294
parent b6bb00b4
......@@ -126,7 +126,7 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Mask size must be set.')
return FullyConnectedMaskHead(num_init_channels, mask_size)
elif name == 'embedding_distance_probability':
elif name == 'embedding_projection':
return tf.keras.layers.Lambda(lambda x: x)
elif name.startswith('resnet'):
......@@ -266,8 +266,8 @@ def fill_boxes(boxes, height, width):
return tf.cast(filled_boxes, tf.float32)
def embedding_projection(x, y):
  """Compute dot product between two given embeddings.

  The two inputs are multiplied elementwise (so broadcasting along the
  leading axes is allowed, e.g. a [num_instances, 1, 1, dim] instance
  embedding against a [num_instances, height, width, dim] pixel embedding)
  and reduced over the embedding dimension.

  Args:
    x: [num_instances, height, width, dimension] float tensor input.
    y: [num_instances, height, width, dimension] float tensor input.

  Returns:
    dot: [num_instances, height, width, 1] A float tensor returning
      the per-pixel embedding projection.
  """
  # keepdims retains a trailing singleton channel so the result stays rank-4.
  dot = tf.reduce_sum(x * y, axis=3, keepdims=True)
  return dot
class ResNetMaskNetwork(tf.keras.layers.Layer):
......@@ -364,6 +363,92 @@ class FullyConnectedMaskHead(tf.keras.layers.Layer):
[num_instances, self.mask_size, self.mask_size, 1])
class DenseResidualBlock(tf.keras.layers.Layer):
  """Residual block for 1D inputs.

  This class implements the pre-activation version of the ResNet block:
  (BatchNorm -> ReLU -> Dense) applied twice, plus an additive shortcut.
  """

  def __init__(self, hidden_size, use_shortcut_linear):
    """Residual Block for 1D inputs.

    Args:
      hidden_size: size of the hidden layer.
      use_shortcut_linear: bool, whether or not to use a linear layer for
        shortcut. Required when the input width differs from hidden_size,
        since the identity shortcut cannot be added otherwise.
    """
    super(DenseResidualBlock, self).__init__()

    # Pre-activation normalization preceding each dense layer.
    self.bn_0 = tf.keras.layers.experimental.SyncBatchNormalization(axis=-1)
    self.bn_1 = tf.keras.layers.experimental.SyncBatchNormalization(axis=-1)

    self.fc_0 = tf.keras.layers.Dense(
        hidden_size, activation=None)
    # Zero-initialized kernel makes the residual branch a no-op at init,
    # so the block starts out as (approximately) the shortcut mapping.
    self.fc_1 = tf.keras.layers.Dense(
        hidden_size, activation=None, kernel_initializer='zeros')
    self.activation = tf.keras.layers.Activation('relu')

    if use_shortcut_linear:
      # Bias-free linear projection to match the residual branch's width.
      self.shortcut = tf.keras.layers.Dense(
          hidden_size, activation=None, use_bias=False)
    else:
      # Identity shortcut; assumes input width == hidden_size.
      self.shortcut = tf.keras.layers.Lambda(lambda x: x)

  def call(self, inputs):
    """Layer's forward pass.

    Implemented as `call` (not `__call__`) so that Keras'
    `Layer.__call__` wrapper (building, name scoping, input tracking)
    stays in effect; invoking the layer as `block(inputs)` is unchanged.

    Args:
      inputs: input tensor of shape [batch, channels].

    Returns:
      Tensor after residual block, shape [batch, hidden_size].
    """
    out = self.fc_0(self.activation(self.bn_0(inputs)))
    residual_out = self.fc_1(self.activation(self.bn_1(out)))
    skip = self.shortcut(inputs)
    return residual_out + skip
class DenseResNet(tf.keras.layers.Layer):
  """Resnet with dense layers."""

  def __init__(self, num_layers, hidden_size, output_size):
    """Resnet with dense layers.

    Args:
      num_layers: int, the number of layers. Must be of the form 2n + 2
        with n >= 1: an input-projection block (2 layers) followed by n
        residual blocks of 2 layers each.
      hidden_size: size of the hidden layer.
      output_size: size of the output.

    Raises:
      ValueError: if num_layers < 4 or is not of the form 2n + 2.
    """
    super(DenseResNet, self).__init__()

    # Validate the depth before allocating any sublayers, so an invalid
    # configuration fails fast without building throwaway variables.
    if num_layers < 4:
      raise ValueError(
          'Cannot construct a DenseResNet with less than 4 layers')
    num_blocks = (num_layers - 2) // 2
    if ((num_blocks * 2) + 2) != num_layers:
      raise ValueError(('DenseResNet depth has to be of the form (2n + 2). '
                        f'Found {num_layers}'))
    self._num_blocks = num_blocks

    # The input projection uses a linear shortcut so arbitrary input
    # widths can be mapped to hidden_size.
    self.input_proj = DenseResidualBlock(hidden_size, use_shortcut_linear=True)
    blocks = [DenseResidualBlock(hidden_size, use_shortcut_linear=False)
              for _ in range(num_blocks)]
    self.resnet = tf.keras.Sequential(blocks)
    self.out_conv = tf.keras.layers.Dense(output_size)

  def call(self, inputs):
    """Forward pass: project input, apply residual stack, map to output_size.

    Implemented as `call` (not `__call__`) so Keras' `Layer.__call__`
    machinery stays in effect; `model(inputs)` behaves as before.
    """
    net = self.input_proj(inputs)
    return self.out_conv(self.resnet(net))
class MaskHeadNetwork(tf.keras.layers.Layer):
"""Mask head class for DeepMAC."""
......@@ -392,11 +477,11 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
self._network_type = network_type
if (self._use_instance_embedding and
(self._network_type == 'embedding_distance_probability')):
(self._network_type == 'embedding_projection')):
raise ValueError(('Cannot feed instance embedding to mask head when '
'computing distance from instance embedding.'))
'computing embedding projection.'))
if network_type == 'embedding_distance_probability':
if network_type == 'embedding_projection':
self.project_out = tf.keras.layers.Lambda(lambda x: x)
else:
self.project_out = tf.keras.layers.Conv2D(
......@@ -432,9 +517,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
if isinstance(out, list):
out = out[-1]
if self._network_type == 'embedding_distance_probability':
if self._network_type == 'embedding_projection':
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
out = embedding_distance_to_probability(instance_embedding, out)
out = embedding_projection(instance_embedding, out)
if out.shape[-1] > 1:
out = self.project_out(out)
......@@ -502,24 +587,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.')
if self._deepmac_params.network_type == 'embedding_distance_probability':
if self._deepmac_params.network_type == 'embedding_projection':
if self._deepmac_params.use_xy:
raise ValueError(
'Cannot use x/y coordinates when using embedding distance.')
'Cannot use x/y coordinates when using embedding projection.')
pixel_embedding_dim = self._deepmac_params.pixel_embedding_dim
dim = self._deepmac_params.dim
if dim != pixel_embedding_dim:
raise ValueError(
'When using embedding distance mask head, '
'When using embedding projection mask head, '
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
if ((not isinstance(loss, losses.WeightedDiceClassificationLoss))
or (not loss.is_prediction_probability)):
raise ValueError('Only dice loss with is_prediction_probability=true '
'is supported with embedding distance mask head.')
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
......@@ -964,10 +1045,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits = crop_masks_within_boxes(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
if self._deepmac_params.network_type == 'embedding_distance_probability':
masks_prob = mask_logits
else:
masks_prob = tf.nn.sigmoid(mask_logits)
masks_prob = tf.nn.sigmoid(mask_logits)
return masks_prob
......
......@@ -174,22 +174,41 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
def test_embedding_distance_prob_shape(self):
dist = deepmac_meta_arch.embedding_distance_to_probability(
def test_embedding_projection_prob_shape(self):
dist = deepmac_meta_arch.embedding_projection(
tf.ones((4, 32, 32, 8)), tf.zeros((4, 32, 32, 8)))
self.assertEqual(dist.shape, (4, 32, 32, 1))
@parameterized.parameters([1e-20, 1e20])
def test_embedding_distance_prob_value(self, value):
dist = deepmac_meta_arch.embedding_distance_to_probability(
def test_embedding_projection_value(self, value):
dist = deepmac_meta_arch.embedding_projection(
tf.zeros((1, 1, 1, 8)), value + tf.zeros((1, 1, 1, 8))).numpy()
max_float = np.finfo(dist.dtype).max
self.assertLess(dist.max(), max_float)
self.assertGreater(dist.max(), -max_float)
@parameterized.named_parameters(
[('no_conv_shortcut', (False,)),
('conv_shortcut', (True,))]
)
def test_res_dense_block(self, conv_shortcut):
net = deepmac_meta_arch.DenseResidualBlock(32, conv_shortcut)
out = net(tf.zeros((2, 32)))
self.assertEqual(out.shape, (2, 32))
@parameterized.parameters(
[4, 8, 20]
)
def test_dense_resnet(self, num_layers):
net = deepmac_meta_arch.DenseResNet(num_layers, 16, 8)
out = net(tf.zeros((2, 24)))
self.assertEqual(out.shape, (2, 8))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase):
class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
def test_mask_network(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
......@@ -218,10 +237,10 @@ class DeepMACMaskHeadTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_embedding_distance_zero_dist(self):
def test_mask_network_embedding_projection_zero(self):
net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_distance_probability', num_init_channels=8,
'embedding_projection', num_init_channels=8,
use_instance_embedding=False)
call_func = tf.function(net.__call__)
......@@ -230,10 +249,10 @@ class DeepMACMaskHeadTest(tf.test.TestCase):
self.assertAllGreater(out.numpy(), -np.inf)
self.assertAllLess(out.numpy(), np.inf)
def test_mask_network_embedding_distance_small_dist(self):
def test_mask_network_embedding_projection_small(self):
net = deepmac_meta_arch.MaskHeadNetwork(
'embedding_distance_probability', num_init_channels=-1,
'embedding_projection', num_init_channels=-1,
use_instance_embedding=False)
call_func = tf.function(net.__call__)
......@@ -396,9 +415,9 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_dist(self):
def test_postprocess_emb_proj(self):
model = build_meta_arch(network_type='embedding_distance_probability',
model = build_meta_arch(network_type='embedding_projection',
use_instance_embedding=False,
use_xy=False, pixel_embedding_dim=8,
use_dice_loss=True,
......@@ -412,14 +431,13 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_dist_fullres(self):
def test_postprocess_emb_proj_fullres(self):
model = build_meta_arch(network_type='embedding_distance_probability',
model = build_meta_arch(network_type='embedding_projection',
predict_full_resolution_masks=True,
use_instance_embedding=False,
pixel_embedding_dim=8, use_xy=False,
use_dice_loss=True,
dice_loss_prediction_probability=True)
use_dice_loss=True)
boxes = np.zeros((2, 3, 4), dtype=np.float32)
boxes = tf.constant(boxes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment