"vscode:/vscode.git/clone" did not exist on "1acc1f561ae859c94c3da746d629d1f60cbe00b6"
Commit 3324d816 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Implement embedding based similarity mask head for DeepMAC.

PiperOrigin-RevId: 394512536
parent 0daae829
......@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config):
elif loss_type == 'weighted_dice_classification_loss':
config = loss_config.weighted_dice_classification_loss
return losses.WeightedDiceClassificationLoss(
squared_normalization=config.squared_normalization)
squared_normalization=config.squared_normalization,
is_prediction_probability=config.is_prediction_probability)
else:
raise ValueError('Empty loss config.')
......@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss):
"""
def __init__(self, squared_normalization):
def __init__(self, squared_normalization, is_prediction_probability=False):
  """Initializes the loss object.

  Args:
    squared_normalization: boolean, if set, we square the probabilities in the
      denominator term used for normalization.
    is_prediction_probability: boolean, whether or not the input
      prediction_tensor represents a probability. If false, it is
      first converted to a probability by applying sigmoid.
  """
  self._squared_normalization = squared_normalization
  # Public (no leading underscore) so that consumers of the loss can check
  # whether predictions are expected to already be probabilities.
  self.is_prediction_probability = is_prediction_probability
  super(WeightedDiceClassificationLoss, self).__init__()
def _compute_loss(self,
......@@ -332,6 +336,9 @@ class WeightedDiceClassificationLoss(Loss):
tf.shape(prediction_tensor)[2]),
[1, 1, -1])
if self.is_prediction_probability:
prob_tensor = prediction_tensor
else:
prob_tensor = tf.nn.sigmoid(prediction_tensor)
if self._squared_normalization:
......
......@@ -36,7 +36,8 @@ class DeepMACParams(
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode', 'box_consistency_loss_weight'
'max_roi_jitter_ratio', 'roi_jitter_mode',
'box_consistency_loss_weight',
])):
"""Class holding the DeepMAC network configutration."""
......@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Mask size must be set.')
return FullyConnectedMaskHead(num_init_channels, mask_size)
elif name == 'embedding_distance_probability':
return tf.keras.layers.Lambda(lambda x: x)
elif name.startswith('resnet'):
return ResNetMaskNetwork(name, num_init_channels)
......@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width):
return tf.cast(filled_boxes, tf.float32)
def embedding_distance_to_probability(x, y):
  """Convert per-pixel embedding distance into a probability.

  The output is exp(-||x - y||^2) computed per pixel, so embeddings that
  are close in euclidean distance map to probabilities near 1 and
  far-apart embeddings map to probabilities near 0.

  Args:
    x: [num_instances, height, width, dimension] float tensor input.
    y: [num_instances, height, width, dimension] or
      [num_instances, 1, 1, dimension] float tensor input. When the height
      and width dimensions are 1, TF will broadcast it.

  Returns:
    dist: [num_instances, height, width, 1] A float tensor returning
      the per-pixel probability. Pixels whose embeddings are close in
      euclidean distance get a probability of close to 1.
  """
  squared_distance = tf.reduce_sum(
      tf.math.squared_difference(x, y), axis=3, keepdims=True)
  return tf.exp(-squared_distance)
class ResNetMaskNetwork(tf.keras.layers.Layer):
"""A small wrapper around ResNet blocks to predict masks."""
......@@ -366,6 +389,16 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
network_type, num_init_channels, mask_size)
self._use_instance_embedding = use_instance_embedding
self._network_type = network_type
if (self._use_instance_embedding and
(self._network_type == 'embedding_distance_probability')):
raise ValueError(('Cannot feed instance embedding to mask head when '
'computing distance from instance embedding.'))
if network_type == 'embedding_distance_probability':
self.project_out = tf.keras.layers.Lambda(lambda x: x)
else:
self.project_out = tf.keras.layers.Conv2D(
filters=1, kernel_size=1, activation=None)
......@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
height = tf.shape(pixel_embedding)[1]
width = tf.shape(pixel_embedding)[2]
if self._use_instance_embedding:
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
instance_embedding = tf.tile(instance_embedding, [1, height, width, 1])
if self._use_instance_embedding:
inputs = tf.concat([pixel_embedding, instance_embedding], axis=3)
else:
inputs = pixel_embedding
......@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
if isinstance(out, list):
out = out[-1]
if self._network_type == 'embedding_distance_probability':
instance_embedding = instance_embedding[:, tf.newaxis, tf.newaxis, :]
out = embedding_distance_to_probability(instance_embedding, out)
if out.shape[-1] > 1:
out = self.project_out(out)
......@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if self._deepmac_params.mask_num_subsamples > 0:
raise ValueError('Subsampling masks is currently not supported.')
if self._deepmac_params.network_type == 'embedding_distance_probability':
if self._deepmac_params.use_xy:
raise ValueError(
'Cannot use x/y coordinates when using embedding distance.')
pixel_embedding_dim = self._deepmac_params.pixel_embedding_dim
dim = self._deepmac_params.dim
if dim != pixel_embedding_dim:
raise ValueError(
'When using embedding distance mask head, '
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
if ((not isinstance(loss, losses.WeightedDiceClassificationLoss))
or (not loss.is_prediction_probability)):
raise ValueError('Only dice loss with is_prediction_probability=true '
'is supported with embedding distance mask head.')
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
......@@ -909,6 +964,9 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits = crop_masks_within_boxes(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
if self._deepmac_params.network_type == 'embedding_distance_probability':
masks_prob = mask_logits
else:
masks_prob = tf.nn.sigmoid(mask_logits)
return masks_prob
......
......@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer):
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
mask_num_subsamples=-1):
use_instance_embedding=True, mask_num_subsamples=-1,
network_type='hourglass10', use_xy=True,
pixel_embedding_dim=2,
dice_loss_prediction_probability=False):
"""Builds the DeepMAC meta architecture."""
feature_extractor = DummyFeatureExtractor(
......@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_labeled_classes=False)
if use_dice_loss:
classification_loss = losses.WeightedDiceClassificationLoss(False)
classification_loss = losses.WeightedDiceClassificationLoss(
squared_normalization=False,
is_prediction_probability=dice_loss_prediction_probability)
else:
classification_loss = losses.WeightedSigmoidClassificationLoss()
......@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
classification_loss=classification_loss,
dim=8,
task_loss_weight=1.0,
pixel_embedding_dim=2,
pixel_embedding_dim=pixel_embedding_dim,
allowed_masked_classes_ids=[],
mask_size=16,
mask_num_subsamples=mask_num_subsamples,
use_xy=True,
network_type='hourglass10',
use_instance_embedding=True,
use_xy=use_xy,
network_type=network_type,
use_instance_embedding=use_instance_embedding,
num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128,
......@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACUtilsTest(tf.test.TestCase):
class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_subsample_trivial(self):
"""Test subsampling masks."""
......@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase):
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
def test_embedding_distance_prob_shape(self):
  # The embedding dimension (8) should be reduced to a single probability
  # channel per pixel.
  dist = deepmac_meta_arch.embedding_distance_to_probability(
      tf.ones((4, 32, 32, 8)), tf.zeros((4, 32, 32, 8)))
  self.assertEqual(dist.shape, (4, 32, 32, 1))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase):
@parameterized.parameters([1e-20, 1e20])
def test_embedding_distance_prob_value(self, value):
  # exp(-d^2) must stay finite for both tiny and huge embedding distances:
  # no overflow/underflow to +/-inf.
  dist = deepmac_meta_arch.embedding_distance_to_probability(
      tf.zeros((1, 1, 1, 8)), value + tf.zeros((1, 1, 1, 8))).numpy()
  max_float = np.finfo(dist.dtype).max
  self.assertLess(dist.max(), max_float)
  self.assertGreater(dist.max(), -max_float)
def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch()
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMaskHeadTest(tf.test.TestCase):
def test_mask_network(self):
net = deepmac_meta_arch.MaskHeadNetwork('hourglass10', 8)
......@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out = call_func(tf.zeros((2, 4)), tf.zeros((2, 32, 32, 16)), training=True)
self.assertEqual(out.shape, (2, 32, 32))
def test_mask_network_embedding_distance_zero_dist(self):
  # Identical (all-zero) instance and pixel embeddings give distance 0;
  # the resulting probability must be finite everywhere.
  net = deepmac_meta_arch.MaskHeadNetwork(
      'embedding_distance_probability', num_init_channels=8,
      use_instance_embedding=False)
  call_func = tf.function(net.__call__)
  out = call_func(tf.zeros((2, 7)), tf.zeros((2, 32, 32, 7)), training=True)
  self.assertEqual(out.shape, (2, 32, 32))
  self.assertAllGreater(out.numpy(), -np.inf)
  self.assertAllLess(out.numpy(), np.inf)
def test_mask_network_embedding_distance_small_dist(self):
  # A huge instance embedding (1e6) against zero pixel embeddings drives
  # the probability toward 0; it must still be finite (no inf/nan).
  # num_init_channels=-1 is deliberate: the embedding-distance head maps to
  # an identity layer, so the channel count is unused.
  net = deepmac_meta_arch.MaskHeadNetwork(
      'embedding_distance_probability', num_init_channels=-1,
      use_instance_embedding=False)
  call_func = tf.function(net.__call__)
  out = call_func(1e6 + tf.zeros((2, 7)),
                  tf.zeros((2, 32, 32, 7)), training=True)
  self.assertEqual(out.shape, (2, 32, 32))
  self.assertAllGreater(out.numpy(), -np.inf)
  self.assertAllLess(out.numpy(), np.inf)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch()
def test_get_mask_head_input(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
......@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase):
prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_dist(self):
  # The embedding-distance head requires no instance embedding, no x/y
  # coordinates, pixel_embedding_dim equal to dim, and dice loss operating
  # on probabilities -- matching the DeepMACMetaArch constructor checks.
  model = build_meta_arch(network_type='embedding_distance_probability',
                          use_instance_embedding=False,
                          use_xy=False, pixel_embedding_dim=8,
                          use_dice_loss=True,
                          dice_loss_prediction_probability=True)
  boxes = np.zeros((2, 3, 4), dtype=np.float32)
  boxes[:, :, [0, 2]] = 0.0
  boxes[:, :, [1, 3]] = 8.0
  boxes = tf.constant(boxes)
  masks = model._postprocess_masks(
      boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
  # Masks are cropped/resized to the configured mask_size of 16.
  self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_dist_fullres(self):
  # Same embedding-distance configuration, but with full-resolution masks;
  # outputs should come back at 128x128 (postprocess_crop_size=128).
  model = build_meta_arch(network_type='embedding_distance_probability',
                          predict_full_resolution_masks=True,
                          use_instance_embedding=False,
                          pixel_embedding_dim=8, use_xy=False,
                          use_dice_loss=True,
                          dice_loss_prediction_probability=True)
  boxes = np.zeros((2, 3, 4), dtype=np.float32)
  boxes = tf.constant(boxes)
  masks = model._postprocess_masks(
      boxes, tf.zeros((2, 32, 32, 2)), tf.zeros((2, 32, 32, 2)))
  self.assertEqual(masks.shape, (2, 3, 128, 128))
def test_postprocess_no_crop_resize_shape(self):
model = build_meta_arch(predict_full_resolution_masks=True)
......@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase):
class ResNetMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(['resnet4', 'resnet8', 'resnet20'])
def test_pass(self, name):
def test_forward(self, name):
  # A forward pass through each parameterized ResNet mask-network variant
  # should preserve the batch and spatial dimensions of the input.
  net = deepmac_meta_arch.ResNetMaskNetwork(name, 8)
  out = net(tf.zeros((3, 32, 32, 16)))
  self.assertEqual(out.shape[:3], (3, 32, 32))
......
......@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss {
// If set, we square the probabilities in the denominator term used for
// normalization.
optional bool squared_normalization = 1 [default=false];
// Whether or not the input prediction to the loss function is a
// probability. If not, the input is interpreted as a logit.
optional bool is_prediction_probability = 2 [default=false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment