Commit 3324d816 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement embedding based similarity mask head for DeepMAC.

PiperOrigin-RevId: 394512536
parent 0daae829
......@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config):
elif loss_type == 'weighted_dice_classification_loss':
config = loss_config.weighted_dice_classification_loss
return losses.WeightedDiceClassificationLoss(
squared_normalization=config.squared_normalization)
squared_normalization=config.squared_normalization,
is_prediction_probability=config.is_prediction_probability)
else:
raise ValueError('Empty loss config.')
......@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss):
"""
def __init__(self, squared_normalization):
def __init__(self, squared_normalization, is_prediction_probability=False):
  """Creates a weighted dice classification loss.

  Args:
    squared_normalization: boolean; when True, the probabilities in the
      denominator (normalization) term of the dice loss are squared.
    is_prediction_probability: boolean; whether the incoming
      prediction_tensor already holds probabilities. When False, a sigmoid
      is applied first to convert logits into probabilities.
  """
  # NOTE: is_prediction_probability is deliberately public — consumers
  # (e.g. the DeepMAC meta-arch) inspect it to validate their config.
  self.is_prediction_probability = is_prediction_probability
  self._squared_normalization = squared_normalization
  super().__init__()
def _compute_loss(self,
......@@ -332,7 +336,10 @@ class WeightedDiceClassificationLoss(Loss):
tf.shape(prediction_tensor)[2]),
[1, 1, -1])
prob_tensor = tf.nn.sigmoid(prediction_tensor)
if self.is_prediction_probability:
prob_tensor = prediction_tensor
else:
prob_tensor = tf.nn.sigmoid(prediction_tensor)
if self._squared_normalization:
prob_tensor = tf.pow(prob_tensor, 2)
......
......@@ -36,7 +36,8 @@ class DeepMACParams(
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'box_consistency_loss_weight'
'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight',
])):
"""Class holding the DeepMAC network configuration."""
......@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Mask size must be set.')
return FullyConnectedMaskHead(num_init_channels, mask_size)
elif name == 'embedding_distance_probability':
return tf.keras.layers.Lambda(lambda x: x)
elif name.startswith('resnet'):
return ResNetMaskNetwork(name, num_init_channels)
......@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width):
return tf.cast(filled_boxes, tf.float32)
def embedding_distance_to_probability(x, y):
  """Converts pixel-wise embedding distances into probabilities.

  Args:
    x: [num_instances, height, width, dimension] float tensor input.
    y: [num_instances, height, width, dimension] or
      [num_instances, 1, 1, dimension] float tensor input; singleton
      height/width dimensions are broadcast by TF.

  Returns:
    A [num_instances, height, width, 1] float tensor of per-pixel
    probabilities. Pixels whose embeddings are close in euclidean
    distance get a probability close to 1.
  """
  # exp(-||x - y||^2): zero distance maps to probability 1 and large
  # distances decay smoothly towards 0.
  squared_distance = tf.reduce_sum(
      tf.math.squared_difference(x, y), axis=3, keepdims=True)
  return tf.exp(-squared_distance)
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
......@@ -366,8 +389,18 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
network_type, num_init_channels, mask_size)
self._use_instance_embedding = use_instance_embedding
self.project_out = tf.keras.layers.Conv2D(
filters=1, kernel_size=1, activation=None)
self._network_type = network_type
if (self._use_instance_embedding and
(self._network_type == 'embedding_distance_probability')):
raise ValueError(('Cannot feed instance embedding to mask head when '
'computing distance from instance embedding.'))
if network_type == 'embedding_distance_probability':
self.project_out = tf.keras.layers.Lambda(lambda x: x)
else:
self.project_out = tf.keras.layers.Conv2D(
filters=1, kernel_size=1, activation=None)
def __call__(self, instance_embedding, pixel_embedding, training):
"""Returns mask logits given object center and spatial embeddings.
......@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
height = tf.shape(pixel_embedding)[1]
width = tf.shape(pixel_embedding)[2]
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
instance_embedding = tf.tile(instance_embedding, [1, height, width, 1])
if self._use_instance_embedding:
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
instance_embedding = tf.tile(instance_embedding, [1, height, width, 1])
inputs = tf.concat([pixel_embedding, instance_embedding], axis=3)
else:
inputs = pixel_embedding
......@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
if isinstance(out, list):
out = out[-1]
if self._network_type == 'embedding_distance_probability':
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
out = embedding_distance_to_probability(instance_embedding, out)
if out.shape[-1] > 1:
out = self.project_out(out)
......@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.')
if self._deepmac_params.network_type == 'embedding_distance_probability':
if self._deepmac_params.use_xy:
raise ValueError(
'Cannot use x/y coordinates when using embedding distance.')
pixel_embedding_dim = self._deepmac_params.pixel_embedding_dim
dim = self._deepmac_params.dim
if dim != pixel_embedding_dim:
raise ValueError(
'When using embedding distance mask head, '
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
if ((not isinstance(loss, losses.WeightedDiceClassificationLoss))
or (not loss.is_prediction_probability)):
raise ValueError('Only dice loss with is_prediction_probability=true '
'is supported with embedding distance mask head.')
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
......@@ -909,7 +964,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits = crop_masks_within_boxes(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
masks_prob = tf.nn.sigmoid(mask_logits)
if self._deepmac_params.network_type == 'embedding_distance_probability':
masks_prob = mask_logits
else:
masks_prob = tf.nn.sigmoid(mask_logits)
return masks_prob
......
......@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer):
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
mask_num_subsamples=-1):
use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False):
"""Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor(
......@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_labeled_classes=False)
if use_dice_loss:
classification_loss = losses.WeightedDiceClassificationLoss(False)
classification_loss = losses.WeightedDiceClassificationLoss(
squared_normalization=False,
is_prediction_probability=dice_loss_prediction_probability)
else:
classification_loss = losses.WeightedSigmoidClassificationLoss()
......@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
classification_loss=classification_loss,
dim=8,
task_loss_weight=1.0,
pixel_embedding_dim=2,
pixel_embedding_dim=pixel_embedding_dim,
allowed_masked_classes_ids=[],
mask_size=16,
mask_num_subsamples=mask_num_subsamples,
use_xy=True,
network_type='hourglass10',
use_instance_embedding=True,
use_xy=use_xy,
network_type=network_type,
use_instance_embedding=use_instance_embedding,
num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
......@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase):
class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_subsample_trivial(self):
"""Test subsampling masks."""
......@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase):
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
def test_embedding_distance_prob_shape(self):
  """Output keeps the spatial dims and collapses channels to 1."""
  x = tf.ones((4, 32, 32, 8))
  y = tf.zeros((4, 32, 32, 8))
  prob = deepmac_meta_arch.embedding_distance_to_probability(x, y)
  self.assertEqual(prob.shape, (4, 32, 32, 1))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase):
@parameterized.parameters([1e-20, 1e20])
def test_embedding_distance_prob_value(self, value):
  """Probabilities stay finite for tiny and huge embedding offsets."""
  zeros = tf.zeros((1, 1, 1, 8))
  prob = deepmac_meta_arch.embedding_distance_to_probability(
      zeros, zeros + value).numpy()
  finfo_max = np.finfo(prob.dtype).max
  self.assertGreater(prob.max(), -finfo_max)
  self.assertLess(prob.max(), finfo_max)
def setUp(self): # pylint:disable=g-missing-super-call
# Build one default DeepMAC meta-architecture shared by the tests below.
self.model = build_meta_arch()
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase):
def test_mask_network(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
......@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_embedding_distance_zero_dist(self):
  """Zero instance and pixel embeddings give finite, well-shaped output."""
  net = deepmac_meta_arch.MaskHeadNetwork(
      'embedding_distance_probability', num_init_channels=8,
      use_instance_embedding=False)
  traced = tf.function(net.__call__)
  instance_emb = tf.zeros((2, 7))
  pixel_emb = tf.zeros((2, 32, 32, 7))
  out = traced(instance_emb, pixel_emb, training=True)
  self.assertEqual(out.shape, (2, 32, 32))
  values = out.numpy()
  self.assertAllGreater(values, -np.inf)
  self.assertAllLess(values, np.inf)
def test_mask_network_embedding_distance_small_dist(self):
  """A huge instance-embedding offset still yields finite output."""
  # num_init_channels is unused by this network type; -1 exercises that.
  net = deepmac_meta_arch.MaskHeadNetwork(
      'embedding_distance_probability', num_init_channels=-1,
      use_instance_embedding=False)
  traced = tf.function(net.__call__)
  instance_emb = tf.zeros((2, 7)) + 1e6
  pixel_emb = tf.zeros((2, 32, 32, 7))
  out = traced(instance_emb, pixel_emb, training=True)
  self.assertEqual(out.shape, (2, 32, 32))
  values = out.numpy()
  self.assertAllGreater(values, -np.inf)
  self.assertAllLess(values, np.inf)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): # pylint:disable=g-missing-super-call
# Build one default DeepMAC meta-architecture shared by the tests below.
self.model = build_meta_arch()
def test_get_mask_head_input(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
......@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase):
prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_dist(self):
  """Postprocessing with the embedding-distance head keeps mask shape."""
  model = build_meta_arch(network_type='embedding_distance_probability',
                          use_instance_embedding=False,
                          use_xy=False, pixel_embedding_dim=8,
                          use_dice_loss=True,
                          dice_loss_prediction_probability=True)
  raw_boxes = np.zeros((2, 3, 4), dtype=np.float32)
  raw_boxes[:, :, [0, 2]] = 0.0  # y-coordinates
  raw_boxes[:, :, [1, 3]] = 8.0  # x-coordinates
  masks = model._postprocess_masks(
      tf.constant(raw_boxes), tf.zeros((2, 32, 32, 2)),
      tf.zeros((2, 32, 32, 2)))
  self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_dist_fullres(self):
  """Full-resolution postprocessing resizes masks to 128x128."""
  model = build_meta_arch(network_type='embedding_distance_probability',
                          predict_full_resolution_masks=True,
                          use_instance_embedding=False,
                          pixel_embedding_dim=8, use_xy=False,
                          use_dice_loss=True,
                          dice_loss_prediction_probability=True)
  zero_boxes = tf.constant(np.zeros((2, 3, 4), dtype=np.float32))
  masks = model._postprocess_masks(
      zero_boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
  self.assertEqual(masks.shape, (2, 3, 128, 128))
def test_postprocess_no_crop_resize_shape(self):
model = build_meta_arch(predict_full_resolution_masks=True)
......@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase):
class ResNetMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(['resnet4', 'resnet8', 'resnet20'])
def test_pass(self, name):
def test_forward(self, name):
  """Each ResNet variant keeps batch and spatial dimensions intact."""
  network = deepmac_meta_arch.ResNetMaskNetwork(name, 8)
  prediction = network(tf.zeros((3, 32, 32, 16)))
  self.assertEqual(prediction.shape[:3], (3, 32, 32))
......
......@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss {
// If set, we square the probabilities in the denominator term used for
// normalization.
optional bool squared_normalization = 1 [default=false];
// Whether or not the input prediction to the loss function is a
// probability. If not, the input is to be interpreted as a logit.
optional bool is_prediction_probability = 2 [default=false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment